From bc925b73a65ad57a636e6e6cb5648e4aed027af5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 26 Mar 2016 20:09:01 -0700 Subject: [SPARK-14157][SQL] Parse Drop Function DDL command ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-14157 We only parse create function command. In order to support native drop function command, we need to parse it too. From Hive [manual](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-Create/Drop/ReloadFunction), the drop function command has syntax as: DROP [TEMPORARY] FUNCTION [IF EXISTS] function_name; ## How was this patch tested? Added test into `DDLCommandSuite`. Author: Liang-Chi Hsieh Closes #11959 from viirya/parse-drop-func. --- .../org/apache/spark/sql/execution/SparkQl.scala | 38 +++++++++++++++++--- .../apache/spark/sql/execution/command/ddl.scala | 16 ++++++++- .../sql/execution/command/DDLCommandSuite.scala | 42 +++++++++++++++++++++- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 1 - 4 files changed, 90 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index b9542c7173..c78b9b429c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -165,11 +165,11 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly case _ => parseFailed("Invalid CREATE FUNCTION command", node) } // If database name is specified, there are 3 tokens, otherwise 2. - val (funcName, alias) = funcNameArgs match { + val (dbName, funcName, alias) = funcNameArgs match { case Token(dbName, Nil) :: Token(fname, Nil) :: Token(aname, Nil) :: Nil => - (unquoteString(dbName) + "." + unquoteString(fname), unquoteString(aname)) + (Some(unquoteString(dbName)), unquoteString(fname), unquoteString(aname)) case Token(fname, Nil) :: Token(aname, Nil) :: Nil => - (unquoteString(fname), unquoteString(aname)) + (None, unquoteString(fname), unquoteString(aname)) case _ => parseFailed("Invalid CREATE FUNCTION command", node) } @@ -190,7 +190,37 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly } case _ => parseFailed("Invalid CREATE FUNCTION command", node) } - CreateFunction(funcName, alias, resources, temp.isDefined)(node.source) + CreateFunction(dbName, funcName, alias, resources, temp.isDefined)(node.source) + + // DROP [TEMPORARY] FUNCTION [IF EXISTS] function_name; + case Token("TOK_DROPFUNCTION", args) => + // Example format: + // + // TOK_DROPFUNCTION + // :- db_name + // :- func_name + // :- TOK_IFEXISTS + // +- TOK_TEMPORARY + val (funcNameArgs, otherArgs) = args.partition { + case Token("TOK_IFEXISTS", _) => false + case Token("TOK_TEMPORARY", _) => false + case Token(_, Nil) => true + case _ => parseFailed("Invalid DROP FUNCTION command", node) + } + // If database name is specified, there are 2 tokens, otherwise 1. 
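+        // e.g. "DROP FUNCTION hello.world" yields (Some("hello"), "world"),
+        // while "DROP FUNCTION helloworld" yields (None, "helloworld").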
+ val (dbName, funcName) = funcNameArgs match { + case Token(dbName, Nil) :: Token(fname, Nil) :: Nil => + (Some(unquoteString(dbName)), unquoteString(fname)) + case Token(fname, Nil) :: Nil => + (None, unquoteString(fname)) + case _ => + parseFailed("Invalid DROP FUNCTION command", node) + } + + val Seq(ifExists, temp) = getClauses(Seq( + "TOK_IFEXISTS", "TOK_TEMPORARY"), otherArgs) + + DropFunction(dbName, funcName, ifExists.isDefined, temp.isDefined)(node.source) case Token("TOK_ALTERTABLE", alterTableArgs) => AlterTableCommandParser.parse(node) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 373b557683..a0f5b75284 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.catalog.CatalogFunction import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.BucketSpec @@ -71,12 +72,25 @@ case class DropDatabase( extends NativeDDLCommand(sql) with Logging case class CreateFunction( + databaseName: Option[String], functionName: String, alias: String, resources: Seq[(String, String)], isTemp: Boolean)(sql: String) extends NativeDDLCommand(sql) with Logging +/** + * The DDL command that drops a function. + * ifExists: returns an error if the function doesn't exist, unless this is true. + * isTemp: indicates if it is a temporary function. 
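+ * e.g. "DROP TEMPORARY FUNCTION IF EXISTS helloworld" sets both ifExists and isTemp to true.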
+ */ +case class DropFunction( + databaseName: Option[String], + functionName: String, + ifExists: Boolean, + isTemp: Boolean)(sql: String) + extends NativeDDLCommand(sql) with Logging + case class AlterTableRename( oldName: TableIdentifier, newName: TableIdentifier)(sql: String) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index a33175aa60..18f48ffa94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -117,12 +117,14 @@ class DDLCommandSuite extends PlanTest { val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) val expected1 = CreateFunction( + None, "helloworld", "com.matthewrathbone.example.SimpleUDFExample", Seq(("jar", "/path/to/jar1"), ("jar", "/path/to/jar2")), isTemp = true)(sql1) val expected2 = CreateFunction( - "hello.world", + Some("hello"), + "world", "com.matthewrathbone.example.SimpleUDFExample", Seq(("archive", "/path/to/archive"), ("file", "/path/to/file")), isTemp = false)(sql2) @@ -130,6 +132,44 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + test("drop function") { + val sql1 = "DROP TEMPORARY FUNCTION helloworld" + val sql2 = "DROP TEMPORARY FUNCTION IF EXISTS helloworld" + val sql3 = "DROP FUNCTION hello.world" + val sql4 = "DROP FUNCTION IF EXISTS hello.world" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + val parsed3 = parser.parsePlan(sql3) + val parsed4 = parser.parsePlan(sql4) + + val expected1 = DropFunction( + None, + "helloworld", + ifExists = false, + isTemp = true)(sql1) + val expected2 = DropFunction( + None, + "helloworld", + ifExists = true, + isTemp = true)(sql2) + val expected3 = DropFunction( + Some("hello"), + "world", + ifExists = false, + isTemp = false)(sql3) + val expected4 = DropFunction( + Some("hello"), + "world", + ifExists = true, + isTemp = false)(sql4) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } + test("alter table: rename table") { val sql = "ALTER TABLE table_name RENAME TO new_table_name" val parsed = parser.parsePlan(sql) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 6586b90377..61fe0985c1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -102,7 +102,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_DESCDATABASE", - "TOK_DROPFUNCTION", "TOK_DROPINDEX", "TOK_DROPMACRO", "TOK_DROPROLE", -- cgit v1.2.3 From a01b6a92b5f0287a5236bddb1b817d13f320d489 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 26 Mar 2016 20:12:30 -0700 Subject: [SPARK-14177][SQL] Native Parsing for DDL Command "Describe Database" and "Alter Database" #### What changes were proposed in this pull request? This PR is to provide native parsing support for two DDL commands: ```Describe Database``` and ```Alter Database Set Properties``` Based on the Hive DDL document: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL ##### 1. ALTER DATABASE **Syntax:** ```SQL ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) 
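-- example (from DDLCommandSuite): ALTER DATABASE database_name SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')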
``` - `ALTER DATABASE` is to add new (key, value) pairs into `DBPROPERTIES` ##### 2. DESCRIBE DATABASE **Syntax:** ```SQL DESCRIBE DATABASE [EXTENDED] db_name ``` - `DESCRIBE DATABASE` shows the name of the database, its comment (if one has been set), and its root location on the filesystem. When `extended` is true, it also shows the database's properties #### How was this patch tested? Added the related test cases to `DDLCommandSuite` Author: gatorsmile Author: xiaoli Author: Xiao Li This patch had conflicts when merged, resolved by Committer: Yin Huai Closes #11977 from gatorsmile/parseAlterDatabase. --- .../org/apache/spark/sql/execution/SparkQl.scala | 27 +++++++++++++++ .../apache/spark/sql/execution/command/ddl.scala | 15 +++++++++ .../sql/execution/command/DDLCommandSuite.scala | 38 ++++++++++++++++++++++ .../scala/org/apache/spark/sql/hive/HiveQl.scala | 3 -- 4 files changed, 80 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index c78b9b429c..d4d1992d27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -142,6 +142,33 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly "TOK_IFEXISTS", "TOK_RESTRICT", "TOK_CASCADE"), otherArgs) DropDatabase(databaseName, ifExists.isDefined, restrict = cascade.isEmpty)(node.source) + // ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) + case Token("TOK_ALTERDATABASE_PROPERTIES", Token(dbName, Nil) :: args) => + val databaseName = unquoteString(dbName) + val dbprops = getClause("TOK_DATABASEPROPERTIES", args) + val props = dbprops match { + case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) => + // Example format: + // + // TOK_DATABASEPROPERTIES + // +- TOK_DBPROPLIST + // :- TOK_TABLEPROPERTY + // : :- 'k1' + // : +- 'v1' + // :- TOK_TABLEPROPERTY + // :- 'k2' + // +- 'v2' + extractProps(propList, "TOK_TABLEPROPERTY") + case _ => parseFailed("Invalid ALTER DATABASE command", node) + } + AlterDatabaseProperties(databaseName, props.toMap)(node.source) + + // DESCRIBE DATABASE [EXTENDED] db_name + case Token("TOK_DESCDATABASE", Token(dbName, Nil) :: describeArgs) => + val databaseName = unquoteString(dbName) + val extended = getClauseOption("EXTENDED", describeArgs) + DescribeDatabase(databaseName, extended.isDefined)(node.source) + // CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name // [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri'] ]; case Token("TOK_CREATEFUNCTION", args) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index a0f5b75284..0e51abb44b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -71,6 +71,21 @@ case class DropDatabase( restrict: Boolean)(sql: String) extends NativeDDLCommand(sql) with Logging +/** ALTER DATABASE: add new (key, value) pairs into DBPROPERTIES */ +case class AlterDatabaseProperties( + databaseName: String, + props: Map[String, String])(sql: String) + extends NativeDDLCommand(sql) with Logging + +/** + * DESCRIBE DATABASE: shows the name of the database, its comment (if one has been set), and its + * root location 
on the filesystem. When extended is true, it also shows the database's properties + */ +case class DescribeDatabase( + databaseName: String, + extended: Boolean)(sql: String) + extends NativeDDLCommand(sql) with Logging + case class CreateFunction( databaseName: Option[String], functionName: String, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 18f48ffa94..7a6343748b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -101,6 +101,44 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed7, expected7) } + test("alter database set dbproperties") { + // ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) + val sql1 = "ALTER DATABASE database_name SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')" + val sql2 = "ALTER SCHEMA database_name SET DBPROPERTIES ('a'='a')" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = AlterDatabaseProperties( + "database_name", + Map("a" -> "a", "b" -> "b", "c" -> "c"))(sql1) + val expected2 = AlterDatabaseProperties( + "database_name", + Map("a" -> "a"))(sql2) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + test("describe database") { + // DESCRIBE DATABASE [EXTENDED] db_name; + val sql1 = "DESCRIBE DATABASE EXTENDED db_name" + val sql2 = "DESCRIBE DATABASE db_name" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = DescribeDatabase( + "db_name", + extended = true)(sql1) + val expected2 = DescribeDatabase( + "db_name", + extended = false)(sql2) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + test("create function") { val sql1 = """ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 61fe0985c1..e5bcb9b1db 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -85,7 +85,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging protected val nativeCommands = Seq( "TOK_ALTERDATABASE_OWNER", - "TOK_ALTERDATABASE_PROPERTIES", "TOK_ALTERINDEX_PROPERTIES", "TOK_ALTERINDEX_REBUILD", "TOK_ALTERTABLE_ALTERPARTS", @@ -100,8 +99,6 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging "TOK_CREATEMACRO", "TOK_CREATEROLE", - "TOK_DESCDATABASE", - "TOK_DROPINDEX", "TOK_DROPMACRO", "TOK_DROPROLE", -- cgit v1.2.3 From cfcca732b403b1af406c2507f3efab928e8b9c6c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 27 Mar 2016 20:06:02 +0100 Subject: [MINOR][SQL] Fix substr/substring testcases. ## What changes were proposed in this pull request? This PR fixes the following two testcases in order to test the correct usages. ``` checkSqlGeneration("SELECT substr('This is a test', 'is')") checkSqlGeneration("SELECT substring('This is a test', 'is')") ``` Actually, the testcases works but tests on exceptional cases. 
``` scala> sql("SELECT substr('This is a test', 'is')") res0: org.apache.spark.sql.DataFrame = [substring(This is a test, CAST(is AS INT), 2147483647): string] scala> sql("SELECT substr('This is a test', 'is')").collect() res1: Array[org.apache.spark.sql.Row] = Array([null]) ``` ## How was this patch tested? Pass the modified unit tests. Author: Dongjoon Hyun Closes #11963 from dongjoon-hyun/fix_substr_testcase. --- .../test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala index 75930086ff..bf85d71c66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala @@ -213,8 +213,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT space(2)") checkSqlGeneration("SELECT split('aa2bb3cc', '[1-9]+')") checkSqlGeneration("SELECT space(2)") - checkSqlGeneration("SELECT substr('This is a test', 'is')") - checkSqlGeneration("SELECT substring('This is a test', 'is')") + checkSqlGeneration("SELECT substr('This is a test', 1)") + checkSqlGeneration("SELECT substring('This is a test', 1)") checkSqlGeneration("SELECT substring_index('www.apache.org','.',1)") checkSqlGeneration("SELECT translate('translate', 'rnlt', '123')") checkSqlGeneration("SELECT trim(' SparkSql ')") -- cgit v1.2.3 From 0f02a5c6e63a95f910e6aba572729ca8085ac3ab Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 27 Mar 2016 20:07:31 +0100 Subject: [MINOR][MLLIB] Remove TODO comment DecisionTreeModel.scala ## What changes were proposed in this pull request? This PR fixes the following line and the related code. Historically, this code was added in [SPARK-5597](https://issues.apache.org/jira/browse/SPARK-5597). After [SPARK-5597](https://issues.apache.org/jira/browse/SPARK-5597) was committed, [SPARK-3365](https://issues.apache.org/jira/browse/SPARK-3365) is fixed now. Now, we had better remove the comment without changing persistent code. ```scala - categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed + categories: Seq[Double]) { ``` ## How was this patch tested? Pass the Jenkins tests. Author: Dongjoon Hyun Closes #11966 from dongjoon-hyun/change_categories_type. --- .../scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ea68ff64a8..a87f8a6cde 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -156,7 +156,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging { feature: Int, threshold: Double, featureType: Int, - categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed + categories: Seq[Double]) { def toSplit: Split = { new Split(feature, threshold, FeatureType(featureType), categories.toList) } -- cgit v1.2.3 From 8ef493760f58687df766d03ccf64039635a2609f Mon Sep 17 00:00:00 2001 From: "Joseph K. 
Bradley" Date: Sun, 27 Mar 2016 19:04:18 -0700 Subject: [SPARK-10691][ML] Make LogisticRegressionModel, LinearRegressionModel evaluate() public ## What changes were proposed in this pull request? Made evaluate method public. Fixed LogisticRegressionModel evaluate to handle case when probabilityCol is not specified. ## How was this patch tested? There were already unit tests for these methods. Author: Joseph K. Bradley Closes #11928 from jkbradley/public-evaluate. --- .../apache/spark/ml/classification/LogisticRegression.scala | 12 +++++++----- .../org/apache/spark/ml/regression/LinearRegression.scala | 8 ++++---- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 861b1d4b66..3d1d5b6892 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -539,13 +539,15 @@ class LogisticRegressionModel private[spark] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = { - new BinaryLogisticRegressionSummary( - this.transform(dataset), $(probabilityCol), $(labelCol), $(featuresCol)) + @Since("2.0.0") + def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + // Handle possible missing or invalid prediction columns + val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() + new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), + probabilityColName, $(labelCol), $(featuresCol)) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index b81c588e44..5ec02135cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -412,15 +412,15 @@ class LinearRegressionModel private[ml] ( def hasSummary: Boolean = trainingSummary.isDefined /** - * Evaluates the model on a testset. + * Evaluates the model on a test dataset. * @param dataset Test dataset to evaluate model on. */ - // TODO: decide on a good name before exposing to public API - private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { + @Since("2.0.0") + def evaluate(dataset: DataFrame): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), this, Array(0D)) + $(labelCol), summaryModel, Array(0D)) } /** -- cgit v1.2.3 From aac13fb48c8aa7d6816ea46c2e40154913477717 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 27 Mar 2016 23:50:23 -0700 Subject: [SPARK-14185][SQL][MINOR] Make indentation of debug log for generated code proper ## What changes were proposed in this pull request? The indentation of debug log output by `CodeGenerator` is weird. The first line of the generated code should be put on the next line of the first line of the log message. 
``` 16/03/28 11:10:24 DEBUG CodeGenerator: /* 001 */ /* 002 */ public java.lang.Object generate(Object[] references) { /* 003 */ return new SpecificSafeProjection(references); ... ``` After this patch is applied, we get debug log like as follows. ``` 16/03/28 10:45:50 DEBUG CodeGenerator: /* 001 */ /* 002 */ public java.lang.Object generate(Object[] references) { /* 003 */ return new SpecificSafeProjection(references); ... ``` ## How was this patch tested? Ran some jobs and checked debug logs. Author: Kousuke Saruta Closes #11990 from sarutak/fix-debuglog-indentation. --- .../apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b511b4b3a0..cd490dd676 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -661,7 +661,7 @@ object CodeGenerator extends Logging { logDebug({ // Only add extra debugging info to byte code when we are going to print the source code. evaluator.setDebuggingInformation(true, true, false) - formatted + s"\n$formatted" }) try { -- cgit v1.2.3 From 7b841540180e8d1403d6c95b02e93f129267b34f Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 28 Mar 2016 12:01:33 +0100 Subject: [SPARK-12494][MLLIB] Array out of bound Exception in KMeans Yarn Mode ## What changes were proposed in this pull request? Better error message with k-means init can't be enough samples from input (because it is perhaps empty) ## How was this patch tested? Jenkins tests. Author: Sean Owen Closes #11979 from srowen/SPARK-12494. --- mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index a7beb81980..37a21cd879 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -390,6 +390,8 @@ class KMeans private ( // Initialize each run's first center to a random point. val seed = new XORShiftRandom(this.seed).nextInt() val sample = data.takeSample(true, runs, seed).toSeq + // Could be empty if data is empty; fail with a better message early: + require(sample.size >= runs, s"Required $runs samples but got ${sample.size} from $data") val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense)) /** Merges new centers to centers. */ -- cgit v1.2.3 From b66aa900619a86b7acbb7c3f96abc96ea2faa53c Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 28 Mar 2016 12:04:21 +0100 Subject: [SPARK-14102][CORE] Block `reset` command in SparkShell ## What changes were proposed in this pull request? Spark Shell provides an easy way to use Spark in Scala environment. This PR adds `reset` command to a blocked list, also cleaned up according to the Scala coding style. ```scala scala> sc res0: org.apache.spark.SparkContext = org.apache.spark.SparkContext718fad24 scala> :reset scala> sc :11: error: not found: value sc sc ^ ``` If we blocks `reset`, Spark Shell works like the followings. ```scala scala> :reset reset: no such command. Type :help for help. 
scala> :re re is ambiguous: did you mean :replay or :require? ``` ## How was this patch tested? Manual. Run `bin/spark-shell` and type `:reset`. Author: Dongjoon Hyun Closes #11920 from dongjoon-hyun/SPARK-14102. --- .../main/scala/org/apache/spark/repl/SparkILoop.scala | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 7ed6d3b1f9..db09d6ace1 100644 --- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -19,12 +19,11 @@ package org.apache.spark.repl import java.io.BufferedReader -import Predef.{println => _, _} -import scala.util.Properties.{javaVersion, versionString, javaVmName} - -import scala.tools.nsc.interpreter.{JPrintWriter, ILoop} +import scala.Predef.{println => _, _} import scala.tools.nsc.Settings +import scala.tools.nsc.interpreter.{ILoop, JPrintWriter} import scala.tools.nsc.util.stringFromStream +import scala.util.Properties.{javaVersion, javaVmName, versionString} /** * A Spark-specific interactive shell. @@ -75,11 +74,9 @@ class SparkILoop(in0: Option[BufferedReader], out: JPrintWriter) echo("Type :help for more information.") } - import LoopCommand.{ cmd, nullary } - - private val blockedCommands = Set("implicits", "javap", "power", "type", "kind") + private val blockedCommands = Set("implicits", "javap", "power", "type", "kind", "reset") - /** Standard commands **/ + /** Standard commands */ lazy val sparkStandardCommands: List[SparkILoop.this.LoopCommand] = standardCommands.filter(cmd => !blockedCommands(cmd.name)) @@ -112,9 +109,9 @@ object SparkILoop { val output = new JPrintWriter(new OutputStreamWriter(ostream), true) val repl = new SparkILoop(input, output) - if (sets.classpath.isDefault) + if (sets.classpath.isDefault) { sets.classpath.value = sys.props("java.class.path") - + } repl process sets } } -- cgit v1.2.3 From c8388297c436691a236520d2396deaf556aedb0e Mon Sep 17 00:00:00 2001 From: Chenliang Xu Date: Mon, 28 Mar 2016 08:33:37 -0700 Subject: [SPARK-14187][MLLIB] Fix incorrect use of binarySearch in SparseMatrix ## What changes were proposed in this pull request? Fix incorrect use of binarySearch in SparseMatrix ## How was this patch tested? Unit test added. Author: Chenliang Xu Closes #11992 from luckyrandom/SPARK-14187. --- mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 2 +- .../src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index c6de7751f5..a09bc65cf3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -613,7 +613,7 @@ class SparseMatrix @Since("1.3.0") ( private[mllib] def update(i: Int, j: Int, v: Double): Unit = { val ind = index(i, j) - if (ind == -1) { + if (ind < 0) { throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. 
Only non-zero elements in Sparse Matrices can be updated.") } else { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a02b8c9635..57907f415c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -150,6 +150,10 @@ class MatricesSuite extends SparkFunSuite { sparseMat.update(0, 0, 10.0) } + intercept[NoSuchElementException] { + sparseMat.update(2, 1, 10.0) + } + sparseMat.update(0, 1, 10.0) assert(sparseMat(0, 1) === 10.0) assert(sparseMat.values(2) === 10.0) -- cgit v1.2.3 From 68c0c460bfc51d7f69d09b613c49c212dd0b375c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 28 Mar 2016 09:58:47 -0700 Subject: [SPARK-13742] [CORE] Add non-iterator interface to RandomSampler JIRA: https://issues.apache.org/jira/browse/SPARK-13742 ## What changes were proposed in this pull request? `RandomSampler.sample` currently accepts iterator as input and output another iterator. This makes it inappropriate to use in wholestage codegen of `Sampler` operator #11517. This change is to add non-iterator interface to `RandomSampler`. This change adds a new method `def sample(): Int` to the trait `RandomSampler`. As we don't need to know the actual values of the sampling items, so this new method takes no arguments. This method will decide whether to sample the next item or not. It returns how many times the next item will be sampled. For `BernoulliSampler` and `BernoulliCellSampler`, the returned sampling times can only be 0 or 1. It simply means whether to sample the next item or not. For `PoissonSampler`, the returned value can be more than 1, meaning the next item will be sampled multiple times. ## How was this patch tested? Tests are added into `RandomSamplerSuite`. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #11578 from viirya/random-sampler-no-iterator. --- .../apache/spark/util/random/RandomSampler.scala | 201 +++++++++------------ .../spark/rdd/PartitionwiseSampledRDDSuite.scala | 2 + .../spark/util/random/RandomSamplerSuite.scala | 197 ++++++++++++++++++++ 3 files changed, 289 insertions(+), 111 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 3c61528ab5..2921b939bc 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -39,7 +39,14 @@ import org.apache.spark.annotation.DeveloperApi trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable { /** take a random sample */ - def sample(items: Iterator[T]): Iterator[U] + def sample(items: Iterator[T]): Iterator[U] = + items.filter(_ => sample > 0).asInstanceOf[Iterator[U]] + + /** + * Whether to sample the next item or not. + * Return how many times the next item will be sampled. Return 0 if it is not sampled. 
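+   * For BernoulliSampler and BernoulliCellSampler the result is always 0 or 1; for
+   * PoissonSampler it can be greater than 1, meaning the item is sampled multiple times.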
+ */ + def sample(): Int /** return a copy of the RandomSampler object */ override def clone: RandomSampler[T, U] = @@ -107,21 +114,13 @@ class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = fals override def setSeed(seed: Long): Unit = rng.setSeed(seed) - override def sample(items: Iterator[T]): Iterator[T] = { + override def sample(): Int = { if (ub - lb <= 0.0) { - if (complement) items else Iterator.empty + if (complement) 1 else 0 } else { - if (complement) { - items.filter { item => { - val x = rng.nextDouble() - (x < lb) || (x >= ub) - }} - } else { - items.filter { item => { - val x = rng.nextDouble() - (x >= lb) && (x < ub) - }} - } + val x = rng.nextDouble() + val n = if ((x >= lb) && (x < ub)) 1 else 0 + if (complement) 1 - n else n } } @@ -155,15 +154,22 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T override def setSeed(seed: Long): Unit = rng.setSeed(seed) - override def sample(items: Iterator[T]): Iterator[T] = { + private lazy val gapSampling: GapSampling = + new GapSampling(fraction, rng, RandomSampler.rngEpsilon) + + override def sample(): Int = { if (fraction <= 0.0) { - Iterator.empty + 0 } else if (fraction >= 1.0) { - items + 1 } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon) + gapSampling.sample() } else { - items.filter { _ => rng.nextDouble() <= fraction } + if (rng.nextDouble() <= fraction) { + 1 + } else { + 0 + } } } @@ -201,15 +207,29 @@ class PoissonSampler[T: ClassTag]( rngGap.setSeed(seed) } - override def sample(items: Iterator[T]): Iterator[T] = { + private lazy val gapSamplingReplacement = + new GapSamplingReplacement(fraction, rngGap, RandomSampler.rngEpsilon) + + override def sample(): Int = { if (fraction <= 0.0) { - Iterator.empty + 0 } else if (useGapSamplingIfPossible && fraction <= RandomSampler.defaultMaxGapSamplingFraction) { - new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon) + gapSamplingReplacement.sample() + } else { + rng.sample() + } + } + + override def sample(items: Iterator[T]): Iterator[T] = { + if (fraction <= 0.0) { + Iterator.empty } else { + val useGapSampling = useGapSamplingIfPossible && + fraction <= RandomSampler.defaultMaxGapSamplingFraction + items.flatMap { item => - val count = rng.sample() + val count = if (useGapSampling) gapSamplingReplacement.sample() else rng.sample() if (count == 0) Iterator.empty else Iterator.fill(count)(item) } } @@ -220,50 +240,36 @@ class PoissonSampler[T: ClassTag]( private[spark] -class GapSamplingIterator[T: ClassTag]( - var data: Iterator[T], +class GapSampling( f: Double, rng: Random = RandomSampler.newDefaultRNG, - epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] { + epsilon: Double = RandomSampler.rngEpsilon) extends Serializable { require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)") require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0") - /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. 
*/ - private val iterDrop: Int => Unit = { - val arrayClass = Array.empty[T].iterator.getClass - val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass - data.getClass match { - case `arrayClass` => - (n: Int) => { data = data.drop(n) } - case `arrayBufferClass` => - (n: Int) => { data = data.drop(n) } - case _ => - (n: Int) => { - var j = 0 - while (j < n && data.hasNext) { - data.next() - j += 1 - } - } - } - } - - override def hasNext: Boolean = data.hasNext + private val lnq = math.log1p(-f) - override def next(): T = { - val r = data.next() - advance() - r + /** Return 1 if the next item should be sampled. Otherwise, return 0. */ + def sample(): Int = { + if (countForDropping > 0) { + countForDropping -= 1 + 0 + } else { + advance() + 1 + } } - private val lnq = math.log1p(-f) + private var countForDropping: Int = 0 - /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */ + /** + * Decide the number of elements that won't be sampled, + * according to geometric dist P(k) = (f)(1-f)^k. + */ private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) - val k = (math.log(u) / lnq).toInt - iterDrop(k) + countForDropping = (math.log(u) / lnq).toInt } /** advance to first sample as part of object construction. */ @@ -273,73 +279,24 @@ class GapSamplingIterator[T: ClassTag]( // work reliably. } + private[spark] -class GapSamplingReplacementIterator[T: ClassTag]( - var data: Iterator[T], - f: Double, - rng: Random = RandomSampler.newDefaultRNG, - epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] { +class GapSamplingReplacement( + val f: Double, + val rng: Random = RandomSampler.newDefaultRNG, + epsilon: Double = RandomSampler.rngEpsilon) extends Serializable { require(f > 0.0, s"Sampling fraction ($f) must be > 0") require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0") - /** implement efficient linear-sequence drop until scala includes fix for jira SI-8835. */ - private val iterDrop: Int => Unit = { - val arrayClass = Array.empty[T].iterator.getClass - val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass - data.getClass match { - case `arrayClass` => - (n: Int) => { data = data.drop(n) } - case `arrayBufferClass` => - (n: Int) => { data = data.drop(n) } - case _ => - (n: Int) => { - var j = 0 - while (j < n && data.hasNext) { - data.next() - j += 1 - } - } - } - } - - /** current sampling value, and its replication factor, as we are sampling with replacement. */ - private var v: T = _ - private var rep: Int = 0 - - override def hasNext: Boolean = data.hasNext || rep > 0 - - override def next(): T = { - val r = v - rep -= 1 - if (rep <= 0) advance() - r - } - - /** - * Skip elements with replication factor zero (i.e. elements that won't be sampled). - * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is - * q is the probability of Poisson(0; f) - */ - private def advance(): Unit = { - val u = math.max(rng.nextDouble(), epsilon) - val k = (math.log(u) / (-f)).toInt - iterDrop(k) - // set the value and replication factor for the next value - if (data.hasNext) { - v = data.next() - rep = poissonGE1 - } - } - - private val q = math.exp(-f) + protected val q = math.exp(-f) /** * Sample from Poisson distribution, conditioned such that the sampled value is >= 1. 
* This is an adaptation from the algorithm for Generating Poisson distributed random variables: * http://en.wikipedia.org/wiki/Poisson_distribution */ - private def poissonGE1: Int = { + protected def poissonGE1: Int = { // simulate that the standard poisson sampling // gave us at least one iteration, for a sample of >= 1 var pp = q + ((1.0 - q) * rng.nextDouble()) @@ -353,6 +310,28 @@ class GapSamplingReplacementIterator[T: ClassTag]( } r } + private var countForDropping: Int = 0 + + def sample(): Int = { + if (countForDropping > 0) { + countForDropping -= 1 + 0 + } else { + val r = poissonGE1 + advance() + r + } + } + + /** + * Skip elements with replication factor zero (i.e. elements that won't be sampled). + * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is + * q is the probabililty of Poisson(0; f) + */ + private def advance(): Unit = { + val u = math.max(rng.nextDouble(), epsilon) + countForDropping = (math.log(u) / (-f)).toInt + } /** advance to first sample as part of object construction. */ advance() diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala index 132a5fa9a8..cb0de1c6be 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala @@ -29,6 +29,8 @@ class MockSampler extends RandomSampler[Long, Long] { s = seed } + override def sample(): Int = 1 + override def sample(items: Iterator[Long]): Iterator[Long] = { Iterator(s) } diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index 791491daf0..7eb2f56c20 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -129,6 +129,13 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { t(m / 2) } + def replacementSampling(data: Iterator[Int], sampler: PoissonSampler[Int]): Iterator[Int] = { + data.flatMap { item => + val count = sampler.sample() + if (count == 0) Iterator.empty else Iterator.fill(count)(item) + } + } + test("utilities") { val s1 = Array(0, 1, 1, 0, 2) val s2 = Array(1, 0, 3, 2, 1) @@ -189,6 +196,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be > D } + test("bernoulli sampling without iterator") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.5))) + d should be < D + + sampler = new BernoulliSampler[Int](0.7) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.7))) + d should be < D + + sampler = new BernoulliSampler[Int](0.9) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.9))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new BernoulliSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => 
sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.6))) + d should be > D + } + test("bernoulli sampling with gap sampling optimization") { // Tests expect maximum gap sampling fraction to be this value RandomSampler.defaultMaxGapSamplingFraction should be (0.4) @@ -217,6 +254,37 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be > D } + test("bernoulli sampling (without iterator) with gap sampling optimization") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.01) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), + gaps(sample(Iterator.from(0), 0.01))) + d should be < D + + sampler = new BernoulliSampler[Int](0.1) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + sampler = new BernoulliSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.3))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new BernoulliSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.4))) + d should be > D + } + test("bernoulli boundary cases") { val data = (1 to 100).toArray @@ -233,6 +301,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { sampler.sample(data.iterator).toArray should be (data) } + test("bernoulli (without iterator) boundary cases") { + val data = (1 to 100).toArray + + var sampler = new BernoulliSampler[Int](0.0) + data.filter(_ => sampler.sample() > 0) should be (Array.empty[Int]) + + sampler = new BernoulliSampler[Int](1.0) + data.filter(_ => sampler.sample() > 0) should be (data) + + sampler = new BernoulliSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0)) + data.filter(_ => sampler.sample() > 0) should be (Array.empty[Int]) + + sampler = new BernoulliSampler[Int](1.0 + (RandomSampler.roundingEpsilon / 2.0)) + data.filter(_ => sampler.sample() > 0) should be (data) + } + test("bernoulli data types") { // Tests expect maximum gap sampling fraction to be this value RandomSampler.defaultMaxGapSamplingFraction should be (0.4) @@ -341,6 +425,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be > D } + test("replacement sampling without iterator") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler = new PoissonSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.5))) + d should be < D + + sampler = new PoissonSampler[Int](0.7) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.7))) + d should be < D + + sampler = new PoissonSampler[Int](0.9) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.9))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new 
PoissonSampler[Int](0.5) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.6))) + d should be > D + } + test("replacement sampling with gap sampling") { // Tests expect maximum gap sampling fraction to be this value RandomSampler.defaultMaxGapSamplingFraction should be (0.4) @@ -369,6 +483,36 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be > D } + test("replacement sampling (without iterator) with gap sampling") { + // Tests expect maximum gap sampling fraction to be this value + RandomSampler.defaultMaxGapSamplingFraction should be (0.4) + + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler = new PoissonSampler[Int](0.01) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.01))) + d should be < D + + sampler = new PoissonSampler[Int](0.1) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.1))) + d should be < D + + sampler = new PoissonSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.3))) + d should be < D + + // sampling at different frequencies should show up as statistically different: + sampler = new PoissonSampler[Int](0.3) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(replacementSampling(data, sampler)), gaps(sampleWR(Iterator.from(0), 0.4))) + d should be > D + } + test("replacement boundary cases") { val data = (1 to 100).toArray @@ -383,6 +527,20 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { sampler.sample(data.iterator).length should be > (data.length) } + test("replacement (without) boundary cases") { + val data = (1 to 100).toArray + + var sampler = new PoissonSampler[Int](0.0) + replacementSampling(data.iterator, sampler).toArray should be (Array.empty[Int]) + + sampler = new PoissonSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0)) + replacementSampling(data.iterator, sampler).toArray should be (Array.empty[Int]) + + // sampling with replacement has no upper bound on sampling fraction + sampler = new PoissonSampler[Int](2.0) + replacementSampling(data.iterator, sampler).length should be > (data.length) + } + test("replacement data types") { // Tests expect maximum gap sampling fraction to be this value RandomSampler.defaultMaxGapSamplingFraction should be (0.4) @@ -477,6 +635,22 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { d should be < D } + test("bernoulli partitioning sampling without iterator") { + var d: Double = 0.0 + + val data = Iterator.from(0) + + var sampler = new BernoulliCellSampler[Int](0.1, 0.2) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.1))) + d should be < D + + sampler = new BernoulliCellSampler[Int](0.1, 0.2, true) + sampler.setSeed(rngSeed.nextLong) + d = medianKSD(gaps(data.filter(_ => sampler.sample() > 0)), gaps(sample(Iterator.from(0), 0.9))) + d should be < D + } + test("bernoulli partitioning boundary cases") { val data = (1 to 100).toArray val d = RandomSampler.roundingEpsilon / 2.0 @@ -500,6 +674,29 @@ class RandomSamplerSuite extends SparkFunSuite with Matchers { sampler.sample(data.iterator).toArray should be (Array.empty[Int]) } + test("bernoulli partitioning (without iterator) boundary cases") { + val data = (1 to 100).toArray + 
val d = RandomSampler.roundingEpsilon / 2.0 + + var sampler = new BernoulliCellSampler[Int](0.0, 0.0) + data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](0.5, 0.5) + data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](1.0, 1.0) + data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int]) + + sampler = new BernoulliCellSampler[Int](0.0, 1.0) + data.filter(_ => sampler.sample() > 0).toArray should be (data) + + sampler = new BernoulliCellSampler[Int](0.0 - d, 1.0 + d) + data.filter(_ => sampler.sample() > 0).toArray should be (data) + + sampler = new BernoulliCellSampler[Int](0.5, 0.5 - d) + data.filter(_ => sampler.sample() > 0).toArray should be (Array.empty[Int]) + } + test("bernoulli partitioning data") { val seed = rngSeed.nextLong val data = (1 to 100).toArray -- cgit v1.2.3 From 40984f67065eeaea731940008e6677c2323dda3e Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Mon, 28 Mar 2016 10:14:28 -0700 Subject: [SPARK-12792] [SPARKR] Refactor RRDD to support R UDF. Refactor RRDD by separating the common logic interacting with the R worker to a new class RRunner, which can be used to evaluate R UDFs. Now RRDD relies on RRuner for RDD computation and RRDD could be reomved if we want to remove RDD API in SparkR later. Author: Sun Rui Closes #10947 from sun-rui/SPARK-12792. --- R/pkg/inst/tests/testthat/test_rdd.R | 8 + .../main/scala/org/apache/spark/api/r/RRDD.scala | 328 +----------------- .../scala/org/apache/spark/api/r/RRunner.scala | 367 +++++++++++++++++++++ 3 files changed, 379 insertions(+), 324 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/api/r/RRunner.scala diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 3b0c16be5a..b6c8e1dc6c 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -791,3 +791,11 @@ test_that("sampleByKey() on pairwise RDDs", { expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) }) + +test_that("Test correct concurrency of RRDD.compute()", { + rdd <- parallelize(sc, 1:1000, 100) + jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") + zrdd <- callJMethod(jrdd, "zip", jrdd) + count <- callJMethod(zrdd, "count") + expect_equal(count, 1000) +}) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 588a57e65f..606ba6ef86 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,21 +17,16 @@ package org.apache.spark.api.r -import java.io._ -import java.net.{InetAddress, ServerSocket} -import java.util.{Arrays, Map => JMap} +import java.util.{Map => JMap} import scala.collection.JavaConverters._ -import scala.io.Source import scala.reflect.ClassTag -import scala.util.Try import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( parent: RDD[T], @@ -42,188 +37,16 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( packageNames: Array[Byte], broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with 
Logging { - protected var dataStream: DataInputStream = _ - private var bootTime: Double = _ override def getPartitions: Array[Partition] = parent.partitions override def compute(partition: Partition, context: TaskContext): Iterator[U] = { - - // Timing start - bootTime = System.currentTimeMillis / 1000.0 + val runner = new RRunner[U]( + func, deserializer, serializer, packageNames, broadcastVars, numPartitions) // The parent may be also an RRDD, so we should launch it first. val parentIterator = firstParent[T].iterator(partition, context) - // we expect two connections - val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) - val listenPort = serverSocket.getLocalPort() - - // The stdout/stderr is shared by multiple tasks, because we use one daemon - // to launch child process as worker. - val errThread = RRDD.createRWorker(listenPort) - - // We use two sockets to separate input and output, then it's easy to manage - // the lifecycle of them to avoid deadlock. - // TODO: optimize it to use one socket - - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() - - try { - - return new Iterator[U] { - def next(): U = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - var _nextObj = read() - - def hasNext(): Boolean = { - val hasMore = (_nextObj != null) - if (!hasMore) { - dataStream.close() - } - hasMore - } - } - } catch { - case e: Exception => - throw new SparkException("R computation failed with\n " + errThread.getLines()) - } - } - - /** - * Start a thread to write RDD data to the R process. - */ - private def startStdinThread[T]( - output: OutputStream, - iter: Iterator[T], - partition: Int): Unit = { - - val env = SparkEnv.get - val taskContext = TaskContext.get() - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val stream = new BufferedOutputStream(output, bufferSize) - - new Thread("writer for R") { - override def run(): Unit = { - try { - SparkEnv.set(env) - TaskContext.setTaskContext(taskContext) - val dataOut = new DataOutputStream(stream) - dataOut.writeInt(partition) - - SerDe.writeString(dataOut, deserializer) - SerDe.writeString(dataOut, serializer) - - dataOut.writeInt(packageNames.length) - dataOut.write(packageNames) - - dataOut.writeInt(func.length) - dataOut.write(func) - - dataOut.writeInt(broadcastVars.length) - broadcastVars.foreach { broadcast => - // TODO(shivaram): Read a Long in R to avoid this cast - dataOut.writeInt(broadcast.id.toInt) - // TODO: Pass a byte array from R to avoid this cast ? 
- val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] - dataOut.writeInt(broadcastByteArr.length) - dataOut.write(broadcastByteArr) - } - - dataOut.writeInt(numPartitions) - - if (!iter.hasNext) { - dataOut.writeInt(0) - } else { - dataOut.writeInt(1) - } - - val printOut = new PrintStream(stream) - - def writeElem(elem: Any): Unit = { - if (deserializer == SerializationFormats.BYTE) { - val elemArr = elem.asInstanceOf[Array[Byte]] - dataOut.writeInt(elemArr.length) - dataOut.write(elemArr) - } else if (deserializer == SerializationFormats.ROW) { - dataOut.write(elem.asInstanceOf[Array[Byte]]) - } else if (deserializer == SerializationFormats.STRING) { - // write string(for StringRRDD) - // scalastyle:off println - printOut.println(elem) - // scalastyle:on println - } - } - - for (elem <- iter) { - elem match { - case (key, value) => - writeElem(key) - writeElem(value) - case _ => - writeElem(elem) - } - } - stream.flush() - } catch { - // TODO: We should propogate this error to the task thread - case e: Exception => - logError("R Writer thread got an exception", e) - } finally { - Try(output.close()) - } - } - }.start() - } - - protected def readData(length: Int): U - - protected def read(): U = { - try { - val length = dataStream.readInt() - - length match { - case SpecialLengths.TIMING_DATA => - // Timing data from R worker - val boot = dataStream.readDouble - bootTime - val init = dataStream.readDouble - val broadcast = dataStream.readDouble - val input = dataStream.readDouble - val compute = dataStream.readDouble - val output = dataStream.readDouble - logInfo( - ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + - "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + - "total = %.3f s").format( - boot, - init, - broadcast, - input, - compute, - output, - boot + init + broadcast + input + compute + output)) - read() - case length if length >= 0 => - readData(length) - } - } catch { - case eof: EOFException => - throw new SparkException("R worker exited unexpectedly (cranshed)", eof) - } + runner.compute(parentIterator, partition.index, context) } } @@ -242,19 +65,6 @@ private class PairwiseRRDD[T: ClassTag]( parent, numPartitions, hashFunc, deserializer, SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - - override protected def readData(length: Int): (Int, Array[Byte]) = { - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null - } - } - lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) } @@ -271,17 +81,6 @@ private class RRDD[T: ClassTag]( extends BaseRRDD[T, Array[Byte]]( parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - - override protected def readData(length: Int): Array[Byte] = { - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj) - obj - case _ => null - } - } - lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -297,55 +96,10 @@ private class StringRRDD[T: ClassTag]( extends BaseRRDD[T, String]( parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { - - override protected def readData(length: Int): String = { - 
length match { - case length if length > 0 => - SerDe.readStringBytes(dataStream, length) - case _ => null - } - } - lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } -private object SpecialLengths { - val TIMING_DATA = -1 -} - -private[r] class BufferedStreamThread( - in: InputStream, - name: String, - errBufferSize: Int) extends Thread(name) with Logging { - val lines = new Array[String](errBufferSize) - var lineIdx = 0 - override def run() { - for (line <- Source.fromInputStream(in).getLines) { - synchronized { - lines(lineIdx) = line - lineIdx = (lineIdx + 1) % errBufferSize - } - logInfo(line) - } - } - - def getLines(): String = synchronized { - (0 until errBufferSize).filter { x => - lines((x + lineIdx) % errBufferSize) != null - }.map { x => - lines((x + lineIdx) % errBufferSize) - }.mkString("\n") - } -} - private[r] object RRDD { - // Because forking processes from Java is expensive, we prefer to launch - // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. - // This daemon currently only works on UNIX-based systems now, so we should - // also fall back to launching workers (worker.R) directly. - private[this] var errThread: BufferedStreamThread = _ - private[this] var daemonChannel: DataOutputStream = _ - def createSparkContext( master: String, appName: String, @@ -353,7 +107,6 @@ private[r] object RRDD { jars: Array[String], sparkEnvirMap: JMap[Object, Object], sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { - val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) @@ -380,78 +133,6 @@ private[r] object RRDD { jsc } - /** - * Start a thread to print the process's stderr to ours - */ - private def startStdoutThread(proc: Process): BufferedStreamThread = { - val BUFFER_SIZE = 100 - val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) - thread.setDaemon(true) - thread.start() - thread - } - - private def createRProcess(port: Int, script: String): BufferedStreamThread = { - // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", - // but kept here for backward compatibility. - val sparkConf = SparkEnv.get.conf - var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") - rCommand = sparkConf.get("spark.r.command", rCommand) - - val rOptions = "--vanilla" - val rLibDir = RUtils.sparkRPackagePath(isDriver = false) - val rExecScript = rLibDir(0) + "/SparkR/worker/" + script - val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) - // Unset the R_TESTS environment variable for workers. - // This is set by R CMD check as startup.Rs - // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) - // and confuses worker script which tries to load a non-existent file - pb.environment().put("R_TESTS", "") - pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) - pb.environment().put("SPARKR_WORKER_PORT", port.toString) - pb.redirectErrorStream(true) // redirect stderr into stdout - val proc = pb.start() - val errThread = startStdoutThread(proc) - errThread - } - - /** - * ProcessBuilder used to launch worker R processes. 
- */ - def createRWorker(port: Int): BufferedStreamThread = { - val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) - if (!Utils.isWindows && useDaemon) { - synchronized { - if (daemonChannel == null) { - // we expect one connections - val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) - val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(daemonPort, "daemon.R") - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() - } - try { - daemonChannel.writeInt(port) - daemonChannel.flush() - } catch { - case e: IOException => - // daemon process died - daemonChannel.close() - daemonChannel = null - errThread = null - // fail the current task, retry by scheduler - throw e - } - errThread - } - } else { - createRProcess(port, "worker.R") - } - } - /** * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is * called from R. @@ -459,5 +140,4 @@ private[r] object RRDD { def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) } - } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala new file mode 100644 index 0000000000..e8fcada453 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io._ +import java.net.{InetAddress, ServerSocket} +import java.util.Arrays + +import scala.io.Source +import scala.util.Try + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.util.Utils + +/** + * A helper class to run R UDFs in Spark. 
+ */ +private[spark] class RRunner[U]( + func: Array[Byte], + deserializer: String, + serializer: String, + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + numPartitions: Int = -1) + extends Logging { + private var bootTime: Double = _ + private var dataStream: DataInputStream = _ + val readData = numPartitions match { + case -1 => + serializer match { + case SerializationFormats.STRING => readStringData _ + case _ => readByteArrayData _ + } + case _ => readShuffledData _ + } + + def compute( + inputIterator: Iterator[_], + partitionIndex: Int, + context: TaskContext): Iterator[U] = { + // Timing start + bootTime = System.currentTimeMillis / 1000.0 + + // we expect two connections + val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRunner.createRWorker(listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. + // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + dataStream = new DataInputStream(inputStream) + serverSocket.close() + + try { + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread( + output: OutputStream, + iter: Iterator[_], + partitionIndex: Int): Unit = { + val env = SparkEnv.get + val taskContext = TaskContext.get() + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partitionIndex) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? 
+ val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) + // scalastyle:off println + printOut.println(elem) + // scalastyle:on println + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + private def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length).asInstanceOf[U] + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) + } + } + + private def readShuffledData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } + } + + private def readByteArrayData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null + } + } + + private def readStringData(length: Int): String = { + length match { + case length if length > 0 => + SerDe.readStringBytes(dataStream, length) + case _ => null + } + } +} + +private object SpecialLengths { + val TIMING_DATA = -1 +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + +private[r] object RRunner { + // Because forking processes from Java is expensive, we prefer to launch + // a single 
R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(port: Int, script: String): BufferedStreamThread = { + // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", + // but kept here for backward compatibility. + val sparkConf = SparkEnv.get.conf + var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") + rCommand = sparkConf.get("spark.r.command", rCommand) + + val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) + val rExecScript = rLibDir(0) + "/SparkR/worker/" + script + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. + */ + def createRWorker(port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(port, "worker.R") + } + } +} -- cgit v1.2.3 From e5a1b301fbe191f1a9627a1083d960c98f543d13 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 28 Mar 2016 10:21:02 -0700 Subject: Revert "[SPARK-12792] [SPARKR] Refactor RRDD to support R UDF." This reverts commit 40984f67065eeaea731940008e6677c2323dda3e. 
--- R/pkg/inst/tests/testthat/test_rdd.R | 8 - .../main/scala/org/apache/spark/api/r/RRDD.scala | 328 +++++++++++++++++- .../scala/org/apache/spark/api/r/RRunner.scala | 367 --------------------- 3 files changed, 324 insertions(+), 379 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/api/r/RRunner.scala diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index b6c8e1dc6c..3b0c16be5a 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -791,11 +791,3 @@ test_that("sampleByKey() on pairwise RDDs", { expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE) expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE) }) - -test_that("Test correct concurrency of RRDD.compute()", { - rdd <- parallelize(sc, 1:1000, 100) - jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") - zrdd <- callJMethod(jrdd, "zip", jrdd) - count <- callJMethod(zrdd, "count") - expect_equal(count, 1000) -}) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 606ba6ef86..588a57e65f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -17,16 +17,21 @@ package org.apache.spark.api.r -import java.util.{Map => JMap} +import java.io._ +import java.net.{InetAddress, ServerSocket} +import java.util.{Arrays, Map => JMap} import scala.collection.JavaConverters._ +import scala.io.Source import scala.reflect.ClassTag +import scala.util.Try import org.apache.spark._ import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( parent: RDD[T], @@ -37,16 +42,188 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag]( packageNames: Array[Byte], broadcastVars: Array[Broadcast[Object]]) extends RDD[U](parent) with Logging { + protected var dataStream: DataInputStream = _ + private var bootTime: Double = _ override def getPartitions: Array[Partition] = parent.partitions override def compute(partition: Partition, context: TaskContext): Iterator[U] = { - val runner = new RRunner[U]( - func, deserializer, serializer, packageNames, broadcastVars, numPartitions) + + // Timing start + bootTime = System.currentTimeMillis / 1000.0 // The parent may be also an RRDD, so we should launch it first. val parentIterator = firstParent[T].iterator(partition, context) - runner.compute(parentIterator, partition.index, context) + // we expect two connections + val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) + val listenPort = serverSocket.getLocalPort() + + // The stdout/stderr is shared by multiple tasks, because we use one daemon + // to launch child process as worker. + val errThread = RRDD.createRWorker(listenPort) + + // We use two sockets to separate input and output, then it's easy to manage + // the lifecycle of them to avoid deadlock. 
+ // TODO: optimize it to use one socket + + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val inSocket = serverSocket.accept() + startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + val inputStream = new BufferedInputStream(outSocket.getInputStream) + dataStream = new DataInputStream(inputStream) + serverSocket.close() + + try { + + return new Iterator[U] { + def next(): U = { + val obj = _nextObj + if (hasNext) { + _nextObj = read() + } + obj + } + + var _nextObj = read() + + def hasNext(): Boolean = { + val hasMore = (_nextObj != null) + if (!hasMore) { + dataStream.close() + } + hasMore + } + } + } catch { + case e: Exception => + throw new SparkException("R computation failed with\n " + errThread.getLines()) + } + } + + /** + * Start a thread to write RDD data to the R process. + */ + private def startStdinThread[T]( + output: OutputStream, + iter: Iterator[T], + partition: Int): Unit = { + + val env = SparkEnv.get + val taskContext = TaskContext.get() + val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt + val stream = new BufferedOutputStream(output, bufferSize) + + new Thread("writer for R") { + override def run(): Unit = { + try { + SparkEnv.set(env) + TaskContext.setTaskContext(taskContext) + val dataOut = new DataOutputStream(stream) + dataOut.writeInt(partition) + + SerDe.writeString(dataOut, deserializer) + SerDe.writeString(dataOut, serializer) + + dataOut.writeInt(packageNames.length) + dataOut.write(packageNames) + + dataOut.writeInt(func.length) + dataOut.write(func) + + dataOut.writeInt(broadcastVars.length) + broadcastVars.foreach { broadcast => + // TODO(shivaram): Read a Long in R to avoid this cast + dataOut.writeInt(broadcast.id.toInt) + // TODO: Pass a byte array from R to avoid this cast ? 
+ val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] + dataOut.writeInt(broadcastByteArr.length) + dataOut.write(broadcastByteArr) + } + + dataOut.writeInt(numPartitions) + + if (!iter.hasNext) { + dataOut.writeInt(0) + } else { + dataOut.writeInt(1) + } + + val printOut = new PrintStream(stream) + + def writeElem(elem: Any): Unit = { + if (deserializer == SerializationFormats.BYTE) { + val elemArr = elem.asInstanceOf[Array[Byte]] + dataOut.writeInt(elemArr.length) + dataOut.write(elemArr) + } else if (deserializer == SerializationFormats.ROW) { + dataOut.write(elem.asInstanceOf[Array[Byte]]) + } else if (deserializer == SerializationFormats.STRING) { + // write string(for StringRRDD) + // scalastyle:off println + printOut.println(elem) + // scalastyle:on println + } + } + + for (elem <- iter) { + elem match { + case (key, value) => + writeElem(key) + writeElem(value) + case _ => + writeElem(elem) + } + } + stream.flush() + } catch { + // TODO: We should propogate this error to the task thread + case e: Exception => + logError("R Writer thread got an exception", e) + } finally { + Try(output.close()) + } + } + }.start() + } + + protected def readData(length: Int): U + + protected def read(): U = { + try { + val length = dataStream.readInt() + + length match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length >= 0 => + readData(length) + } + } catch { + case eof: EOFException => + throw new SparkException("R worker exited unexpectedly (cranshed)", eof) + } } } @@ -65,6 +242,19 @@ private class PairwiseRRDD[T: ClassTag]( parent, numPartitions, hashFunc, deserializer, SerializationFormats.BYTE, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): (Int, Array[Byte]) = { + length match { + case length if length == 2 => + val hashedKey = dataStream.readInt() + val contentPairsLength = dataStream.readInt() + val contentPairs = new Array[Byte](contentPairsLength) + dataStream.readFully(contentPairs) + (hashedKey, contentPairs) + case _ => null + } + } + lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this) } @@ -81,6 +271,17 @@ private class RRDD[T: ClassTag]( extends BaseRRDD[T, Array[Byte]]( parent, -1, func, deserializer, serializer, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): Array[Byte] = { + length match { + case length if length > 0 => + val obj = new Array[Byte](length) + dataStream.readFully(obj) + obj + case _ => null + } + } + lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) } @@ -96,10 +297,55 @@ private class StringRRDD[T: ClassTag]( extends BaseRRDD[T, String]( parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) { + + override protected def readData(length: Int): String = { + length match { + case length if length > 0 => + 
SerDe.readStringBytes(dataStream, length) + case _ => null + } + } + lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this) } +private object SpecialLengths { + val TIMING_DATA = -1 +} + +private[r] class BufferedStreamThread( + in: InputStream, + name: String, + errBufferSize: Int) extends Thread(name) with Logging { + val lines = new Array[String](errBufferSize) + var lineIdx = 0 + override def run() { + for (line <- Source.fromInputStream(in).getLines) { + synchronized { + lines(lineIdx) = line + lineIdx = (lineIdx + 1) % errBufferSize + } + logInfo(line) + } + } + + def getLines(): String = synchronized { + (0 until errBufferSize).filter { x => + lines((x + lineIdx) % errBufferSize) != null + }.map { x => + lines((x + lineIdx) % errBufferSize) + }.mkString("\n") + } +} + private[r] object RRDD { + // Because forking processes from Java is expensive, we prefer to launch + // a single R daemon (daemon.R) and tell it to fork new workers for our tasks. + // This daemon currently only works on UNIX-based systems now, so we should + // also fall back to launching workers (worker.R) directly. + private[this] var errThread: BufferedStreamThread = _ + private[this] var daemonChannel: DataOutputStream = _ + def createSparkContext( master: String, appName: String, @@ -107,6 +353,7 @@ private[r] object RRDD { jars: Array[String], sparkEnvirMap: JMap[Object, Object], sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = { + val sparkConf = new SparkConf().setAppName(appName) .setSparkHome(sparkHome) @@ -133,6 +380,78 @@ private[r] object RRDD { jsc } + /** + * Start a thread to print the process's stderr to ours + */ + private def startStdoutThread(proc: Process): BufferedStreamThread = { + val BUFFER_SIZE = 100 + val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) + thread.setDaemon(true) + thread.start() + thread + } + + private def createRProcess(port: Int, script: String): BufferedStreamThread = { + // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", + // but kept here for backward compatibility. + val sparkConf = SparkEnv.get.conf + var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") + rCommand = sparkConf.get("spark.r.command", rCommand) + + val rOptions = "--vanilla" + val rLibDir = RUtils.sparkRPackagePath(isDriver = false) + val rExecScript = rLibDir(0) + "/SparkR/worker/" + script + val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) + // Unset the R_TESTS environment variable for workers. + // This is set by R CMD check as startup.Rs + // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) + // and confuses worker script which tries to load a non-existent file + pb.environment().put("R_TESTS", "") + pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) + pb.environment().put("SPARKR_WORKER_PORT", port.toString) + pb.redirectErrorStream(true) // redirect stderr into stdout + val proc = pb.start() + val errThread = startStdoutThread(proc) + errThread + } + + /** + * ProcessBuilder used to launch worker R processes. 
+ */ + def createRWorker(port: Int): BufferedStreamThread = { + val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) + if (!Utils.isWindows && useDaemon) { + synchronized { + if (daemonChannel == null) { + // we expect one connections + val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) + val daemonPort = serverSocket.getLocalPort + errThread = createRProcess(daemonPort, "daemon.R") + // the socket used to send out the input of task + serverSocket.setSoTimeout(10000) + val sock = serverSocket.accept() + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + serverSocket.close() + } + try { + daemonChannel.writeInt(port) + daemonChannel.flush() + } catch { + case e: IOException => + // daemon process died + daemonChannel.close() + daemonChannel = null + errThread = null + // fail the current task, retry by scheduler + throw e + } + errThread + } + } else { + createRProcess(port, "worker.R") + } + } + /** * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is * called from R. @@ -140,4 +459,5 @@ private[r] object RRDD { def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = { JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length)) } + } diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala deleted file mode 100644 index e8fcada453..0000000000 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ /dev/null @@ -1,367 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.api.r - -import java.io._ -import java.net.{InetAddress, ServerSocket} -import java.util.Arrays - -import scala.io.Source -import scala.util.Try - -import org.apache.spark._ -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.util.Utils - -/** - * A helper class to run R UDFs in Spark. 
- */ -private[spark] class RRunner[U]( - func: Array[Byte], - deserializer: String, - serializer: String, - packageNames: Array[Byte], - broadcastVars: Array[Broadcast[Object]], - numPartitions: Int = -1) - extends Logging { - private var bootTime: Double = _ - private var dataStream: DataInputStream = _ - val readData = numPartitions match { - case -1 => - serializer match { - case SerializationFormats.STRING => readStringData _ - case _ => readByteArrayData _ - } - case _ => readShuffledData _ - } - - def compute( - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext): Iterator[U] = { - // Timing start - bootTime = System.currentTimeMillis / 1000.0 - - // we expect two connections - val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost")) - val listenPort = serverSocket.getLocalPort() - - // The stdout/stderr is shared by multiple tasks, because we use one daemon - // to launch child process as worker. - val errThread = RRunner.createRWorker(listenPort) - - // We use two sockets to separate input and output, then it's easy to manage - // the lifecycle of them to avoid deadlock. - // TODO: optimize it to use one socket - - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() - - try { - return new Iterator[U] { - def next(): U = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - var _nextObj = read() - - def hasNext(): Boolean = { - val hasMore = (_nextObj != null) - if (!hasMore) { - dataStream.close() - } - hasMore - } - } - } catch { - case e: Exception => - throw new SparkException("R computation failed with\n " + errThread.getLines()) - } - } - - /** - * Start a thread to write RDD data to the R process. - */ - private def startStdinThread( - output: OutputStream, - iter: Iterator[_], - partitionIndex: Int): Unit = { - val env = SparkEnv.get - val taskContext = TaskContext.get() - val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - val stream = new BufferedOutputStream(output, bufferSize) - - new Thread("writer for R") { - override def run(): Unit = { - try { - SparkEnv.set(env) - TaskContext.setTaskContext(taskContext) - val dataOut = new DataOutputStream(stream) - dataOut.writeInt(partitionIndex) - - SerDe.writeString(dataOut, deserializer) - SerDe.writeString(dataOut, serializer) - - dataOut.writeInt(packageNames.length) - dataOut.write(packageNames) - - dataOut.writeInt(func.length) - dataOut.write(func) - - dataOut.writeInt(broadcastVars.length) - broadcastVars.foreach { broadcast => - // TODO(shivaram): Read a Long in R to avoid this cast - dataOut.writeInt(broadcast.id.toInt) - // TODO: Pass a byte array from R to avoid this cast ? 
- val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]] - dataOut.writeInt(broadcastByteArr.length) - dataOut.write(broadcastByteArr) - } - - dataOut.writeInt(numPartitions) - - if (!iter.hasNext) { - dataOut.writeInt(0) - } else { - dataOut.writeInt(1) - } - - val printOut = new PrintStream(stream) - - def writeElem(elem: Any): Unit = { - if (deserializer == SerializationFormats.BYTE) { - val elemArr = elem.asInstanceOf[Array[Byte]] - dataOut.writeInt(elemArr.length) - dataOut.write(elemArr) - } else if (deserializer == SerializationFormats.ROW) { - dataOut.write(elem.asInstanceOf[Array[Byte]]) - } else if (deserializer == SerializationFormats.STRING) { - // write string(for StringRRDD) - // scalastyle:off println - printOut.println(elem) - // scalastyle:on println - } - } - - for (elem <- iter) { - elem match { - case (key, value) => - writeElem(key) - writeElem(value) - case _ => - writeElem(elem) - } - } - stream.flush() - } catch { - // TODO: We should propogate this error to the task thread - case e: Exception => - logError("R Writer thread got an exception", e) - } finally { - Try(output.close()) - } - } - }.start() - } - - private def read(): U = { - try { - val length = dataStream.readInt() - - length match { - case SpecialLengths.TIMING_DATA => - // Timing data from R worker - val boot = dataStream.readDouble - bootTime - val init = dataStream.readDouble - val broadcast = dataStream.readDouble - val input = dataStream.readDouble - val compute = dataStream.readDouble - val output = dataStream.readDouble - logInfo( - ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + - "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + - "total = %.3f s").format( - boot, - init, - broadcast, - input, - compute, - output, - boot + init + broadcast + input + compute + output)) - read() - case length if length >= 0 => - readData(length).asInstanceOf[U] - } - } catch { - case eof: EOFException => - throw new SparkException("R worker exited unexpectedly (cranshed)", eof) - } - } - - private def readShuffledData(length: Int): (Int, Array[Byte]) = { - length match { - case length if length == 2 => - val hashedKey = dataStream.readInt() - val contentPairsLength = dataStream.readInt() - val contentPairs = new Array[Byte](contentPairsLength) - dataStream.readFully(contentPairs) - (hashedKey, contentPairs) - case _ => null - } - } - - private def readByteArrayData(length: Int): Array[Byte] = { - length match { - case length if length > 0 => - val obj = new Array[Byte](length) - dataStream.readFully(obj) - obj - case _ => null - } - } - - private def readStringData(length: Int): String = { - length match { - case length if length > 0 => - SerDe.readStringBytes(dataStream, length) - case _ => null - } - } -} - -private object SpecialLengths { - val TIMING_DATA = -1 -} - -private[r] class BufferedStreamThread( - in: InputStream, - name: String, - errBufferSize: Int) extends Thread(name) with Logging { - val lines = new Array[String](errBufferSize) - var lineIdx = 0 - override def run() { - for (line <- Source.fromInputStream(in).getLines) { - synchronized { - lines(lineIdx) = line - lineIdx = (lineIdx + 1) % errBufferSize - } - logInfo(line) - } - } - - def getLines(): String = synchronized { - (0 until errBufferSize).filter { x => - lines((x + lineIdx) % errBufferSize) != null - }.map { x => - lines((x + lineIdx) % errBufferSize) - }.mkString("\n") - } -} - -private[r] object RRunner { - // Because forking processes from Java is expensive, we prefer to launch - // a single 
R daemon (daemon.R) and tell it to fork new workers for our tasks. - // This daemon currently only works on UNIX-based systems now, so we should - // also fall back to launching workers (worker.R) directly. - private[this] var errThread: BufferedStreamThread = _ - private[this] var daemonChannel: DataOutputStream = _ - - /** - * Start a thread to print the process's stderr to ours - */ - private def startStdoutThread(proc: Process): BufferedStreamThread = { - val BUFFER_SIZE = 100 - val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE) - thread.setDaemon(true) - thread.start() - thread - } - - private def createRProcess(port: Int, script: String): BufferedStreamThread = { - // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command", - // but kept here for backward compatibility. - val sparkConf = SparkEnv.get.conf - var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript") - rCommand = sparkConf.get("spark.r.command", rCommand) - - val rOptions = "--vanilla" - val rLibDir = RUtils.sparkRPackagePath(isDriver = false) - val rExecScript = rLibDir(0) + "/SparkR/worker/" + script - val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript)) - // Unset the R_TESTS environment variable for workers. - // This is set by R CMD check as startup.Rs - // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R) - // and confuses worker script which tries to load a non-existent file - pb.environment().put("R_TESTS", "") - pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(",")) - pb.environment().put("SPARKR_WORKER_PORT", port.toString) - pb.redirectErrorStream(true) // redirect stderr into stdout - val proc = pb.start() - val errThread = startStdoutThread(proc) - errThread - } - - /** - * ProcessBuilder used to launch worker R processes. - */ - def createRWorker(port: Int): BufferedStreamThread = { - val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true) - if (!Utils.isWindows && useDaemon) { - synchronized { - if (daemonChannel == null) { - // we expect one connections - val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) - val daemonPort = serverSocket.getLocalPort - errThread = createRProcess(daemonPort, "daemon.R") - // the socket used to send out the input of task - serverSocket.setSoTimeout(10000) - val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() - } - try { - daemonChannel.writeInt(port) - daemonChannel.flush() - } catch { - case e: IOException => - // daemon process died - daemonChannel.close() - daemonChannel = null - errThread = null - // fail the current task, retry by scheduler - throw e - } - errThread - } - } else { - createRProcess(port, "worker.R") - } - } -} -- cgit v1.2.3 From 4a7636f2da2121ee8c6fb7e6614820aaf3db8e0f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 28 Mar 2016 10:35:48 -0700 Subject: [SPARK-13844] [SQL] Generate better code for filters with a non-nullable column ## What changes were proposed in this pull request? This PR simplifies generated code with a non-nullable column. This PR addresses three items: 1. Generate simplified code for and / or 2. Generate better code for divide and remainder with non-zero dividend 3. Pass nullable information into BoundReference at WholeStageCodegen I have attached the generated code with and without this PR ## How was this patch tested? 
Tested by existing test suites in sql/core Here is a motivating example ```` (0 to 6).map(i => (i.toString, i.toInt)).toDF("k", "v") .filter("v % 2 == 0").filter("v <= 4").filter("v > 1").show() ```` Generated code without this PR ````java /* 032 */ protected void processNext() throws java.io.IOException { /* 033 */ /*** PRODUCE: Project [_1#0 AS k#3,_2#1 AS v#4] */ /* 034 */ /* 035 */ /*** PRODUCE: Filter ((isnotnull((_2#1 % 2)) && ((_2#1 % 2) = 0)) && ((_2#1 <= 4) && (_2#1 > 1))) */ /* 036 */ /* 037 */ /*** PRODUCE: INPUT */ /* 038 */ /* 039 */ while (!shouldStop() && inputadapter_input.hasNext()) { /* 040 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 041 */ /*** CONSUME: Filter ((isnotnull((_2#1 % 2)) && ((_2#1 % 2) = 0)) && ((_2#1 <= 4) && (_2#1 > 1))) */ /* 042 */ /* input[1, int] */ /* 043 */ int filter_value1 = inputadapter_row.getInt(1); /* 044 */ /* 045 */ /* isnotnull((input[1, int] % 2)) */ /* 046 */ /* (input[1, int] % 2) */ /* 047 */ boolean filter_isNull3 = false; /* 048 */ int filter_value3 = -1; /* 049 */ if (false || 2 == 0) { /* 050 */ filter_isNull3 = true; /* 051 */ } else { /* 052 */ if (false) { /* 053 */ filter_isNull3 = true; /* 054 */ } else { /* 055 */ filter_value3 = (int)(filter_value1 % 2); /* 056 */ } /* 057 */ } /* 058 */ if (!(!(filter_isNull3))) continue; /* 059 */ /* 060 */ /* ((input[1, int] % 2) = 0) */ /* 061 */ boolean filter_isNull6 = true; /* 062 */ boolean filter_value6 = false; /* 063 */ /* (input[1, int] % 2) */ /* 064 */ boolean filter_isNull7 = false; /* 065 */ int filter_value7 = -1; /* 066 */ if (false || 2 == 0) { /* 067 */ filter_isNull7 = true; /* 068 */ } else { /* 069 */ if (false) { /* 070 */ filter_isNull7 = true; /* 071 */ } else { /* 072 */ filter_value7 = (int)(filter_value1 % 2); /* 073 */ } /* 074 */ } /* 075 */ if (!filter_isNull7) { /* 076 */ filter_isNull6 = false; // resultCode could change nullability. /* 077 */ filter_value6 = filter_value7 == 0; /* 078 */ /* 079 */ } /* 080 */ if (filter_isNull6 || !filter_value6) continue; /* 081 */ /* 082 */ /* (input[1, int] <= 4) */ /* 083 */ boolean filter_value11 = false; /* 084 */ filter_value11 = filter_value1 <= 4; /* 085 */ if (!filter_value11) continue; /* 086 */ /* 087 */ /* (input[1, int] > 1) */ /* 088 */ boolean filter_value14 = false; /* 089 */ filter_value14 = filter_value1 > 1; /* 090 */ if (!filter_value14) continue; /* 091 */ /* 092 */ filter_metricValue.add(1); /* 093 */ /* 094 */ /*** CONSUME: Project [_1#0 AS k#3,_2#1 AS v#4] */ /* 095 */ /* 096 */ /* input[0, string] */ /* 097 */ /* input[0, string] */ /* 098 */ boolean filter_isNull = inputadapter_row.isNullAt(0); /* 099 */ UTF8String filter_value = filter_isNull ? 
null : (inputadapter_row.getUTF8String(0)); /* 100 */ project_holder.reset(); /* 101 */ /* 102 */ project_rowWriter.zeroOutNullBytes(); /* 103 */ /* 104 */ if (filter_isNull) { /* 105 */ project_rowWriter.setNullAt(0); /* 106 */ } else { /* 107 */ project_rowWriter.write(0, filter_value); /* 108 */ } /* 109 */ /* 110 */ project_rowWriter.write(1, filter_value1); /* 111 */ project_result.setTotalSize(project_holder.totalSize()); /* 112 */ append(project_result.copy()); /* 113 */ } /* 114 */ } /* 115 */ } ```` Generated code with this PR ````java /* 032 */ protected void processNext() throws java.io.IOException { /* 033 */ /*** PRODUCE: Project [_1#0 AS k#3,_2#1 AS v#4] */ /* 034 */ /* 035 */ /*** PRODUCE: Filter (((_2#1 % 2) = 0) && ((_2#1 <= 5) && (_2#1 > 1))) */ /* 036 */ /* 037 */ /*** PRODUCE: INPUT */ /* 038 */ /* 039 */ while (!shouldStop() && inputadapter_input.hasNext()) { /* 040 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 041 */ /*** CONSUME: Filter (((_2#1 % 2) = 0) && ((_2#1 <= 5) && (_2#1 > 1))) */ /* 042 */ /* input[1, int] */ /* 043 */ int filter_value1 = inputadapter_row.getInt(1); /* 044 */ /* 045 */ /* ((input[1, int] % 2) = 0) */ /* 046 */ /* (input[1, int] % 2) */ /* 047 */ int filter_value3 = (int)(filter_value1 % 2); /* 048 */ /* 049 */ boolean filter_value2 = false; /* 050 */ filter_value2 = filter_value3 == 0; /* 051 */ if (!filter_value2) continue; /* 052 */ /* 053 */ /* (input[1, int] <= 5) */ /* 054 */ boolean filter_value7 = false; /* 055 */ filter_value7 = filter_value1 <= 5; /* 056 */ if (!filter_value7) continue; /* 057 */ /* 058 */ /* (input[1, int] > 1) */ /* 059 */ boolean filter_value10 = false; /* 060 */ filter_value10 = filter_value1 > 1; /* 061 */ if (!filter_value10) continue; /* 062 */ /* 063 */ filter_metricValue.add(1); /* 064 */ /* 065 */ /*** CONSUME: Project [_1#0 AS k#3,_2#1 AS v#4] */ /* 066 */ /* 067 */ /* input[0, string] */ /* 068 */ /* input[0, string] */ /* 069 */ boolean filter_isNull = inputadapter_row.isNullAt(0); /* 070 */ UTF8String filter_value = filter_isNull ? null : (inputadapter_row.getUTF8String(0)); /* 071 */ project_holder.reset(); /* 072 */ /* 073 */ project_rowWriter.zeroOutNullBytes(); /* 074 */ /* 075 */ if (filter_isNull) { /* 076 */ project_rowWriter.setNullAt(0); /* 077 */ } else { /* 078 */ project_rowWriter.write(0, filter_value); /* 079 */ } /* 080 */ /* 081 */ project_rowWriter.write(1, filter_value1); /* 082 */ project_result.setTotalSize(project_holder.totalSize()); /* 083 */ append(project_result.copy()); /* 084 */ } /* 085 */ } /* 086 */ } ```` Author: Kazuaki Ishizaki Closes #11684 from kiszk/SPARK-13844. 
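For readers less familiar with Spark's expression code generation, the sketch below illustrates the branching idea behind this change in plain Scala. It is not the actual `CodegenContext`/`ExprCode` machinery — the case class, method name and variable names here are hypothetical stand-ins — but it shows how, when both children are statically non-nullable, the generator can emit a short-circuiting fragment with no `isNull` bookkeeping, and otherwise fall back to the three-valued-logic form shown in the listings above.

````scala
// Minimal sketch, assuming a simplified stand-in for Spark's ExprCode; not the real codegen API.
object NonNullableCodegenSketch {
  /** A Java fragment plus the variable names it defines (hypothetical stand-in for ExprCode). */
  case class ExprCode(code: String, value: String, isNull: String, nullable: Boolean)

  /** Emit Java source for `left AND right`, using the simplified form when neither side is nullable. */
  def genAnd(left: ExprCode, right: ExprCode, resValue: String, resIsNull: String): String =
    if (!left.nullable && !right.nullable) {
      // Fast path: no isNull flags at all, short-circuit on the boolean value alone.
      s"""${left.code}
         |boolean $resValue = false;
         |if (${left.value}) {
         |  ${right.code}
         |  $resValue = ${right.value};
         |}""".stripMargin
    } else {
      // General path: three-valued logic with explicit null tracking.
      s"""${left.code}
         |boolean $resIsNull = false;
         |boolean $resValue = false;
         |if (!${left.isNull} && !${left.value}) {
         |  // false AND anything => false
         |} else {
         |  ${right.code}
         |  if (!${right.isNull} && !${right.value}) {
         |    // anything AND false => false
         |  } else if (!${left.isNull} && !${right.isNull}) {
         |    $resValue = true;
         |  } else {
         |    $resIsNull = true;
         |  }
         |}""".stripMargin
    }

  def main(args: Array[String]): Unit = {
    // Two non-nullable int predicates, analogous to the `v % 2 == 0` and `v <= 4` filters above.
    val mod = ExprCode("boolean f1 = (filter_value1 % 2) == 0;", "f1", "f1_isNull", nullable = false)
    val le  = ExprCode("boolean f2 = filter_value1 <= 4;", "f2", "f2_isNull", nullable = false)
    println(genAnd(mod, le, "and_value", "and_isNull")) // prints the null-check-free fragment
  }
}
````

The payoff of the non-nullable branch is visible in the "Generated code with this PR" listing above: the filter predicates compile to straight-line boolean checks with no null flags, which is easier for the JIT to optimize.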
--- .../sql/catalyst/expressions/arithmetic.scala | 72 ++++++++++++++------ .../sql/catalyst/expressions/predicates.scala | 78 ++++++++++++++-------- 2 files changed, 102 insertions(+), 48 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ed812e0679..1e9c971800 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -237,21 +237,35 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { + if (!left.nullable && !right.nullable) { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if ($isZero) { ${ev.isNull} = true; } else { + ${eval1.code} ${ev.value} = $divide; } - } - """ + """ + } else { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = $divide; + } + } + """ + } } } @@ -299,21 +313,35 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { + if (!left.nullable && !right.nullable) { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if ($isZero) { ${ev.isNull} = true; } else { + ${eval1.code} ${ev.value} = $remainder; } - } - """ + """ + } else { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = $remainder; + } + } + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 20818bfb1a..e23ad5596b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -274,22 +274,35 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with val eval2 = right.gen(ctx) // The result should be `false`, if any of them is `false` whenever the other is null or not. 
- s""" - ${eval1.code} - boolean ${ev.isNull} = false; - boolean ${ev.value} = false; + if (!left.nullable && !right.nullable) { + ev.isNull = "false" + s""" + ${eval1.code} + boolean ${ev.value} = false; - if (!${eval1.isNull} && !${eval1.value}) { - } else { - ${eval2.code} - if (!${eval2.isNull} && !${eval2.value}) { - } else if (!${eval1.isNull} && !${eval2.isNull}) { - ${ev.value} = true; + if (${eval1.value}) { + ${eval2.code} + ${ev.value} = ${eval2.value}; + } + """ + } else { + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.value} = false; + + if (!${eval1.isNull} && !${eval1.value}) { } else { - ${ev.isNull} = true; + ${eval2.code} + if (!${eval2.isNull} && !${eval2.value}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.value} = true; + } else { + ${ev.isNull} = true; + } } - } - """ + """ + } } } @@ -325,22 +338,35 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P val eval2 = right.gen(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. - s""" - ${eval1.code} - boolean ${ev.isNull} = false; - boolean ${ev.value} = true; + if (!left.nullable && !right.nullable) { + ev.isNull = "false" + s""" + ${eval1.code} + boolean ${ev.value} = true; - if (!${eval1.isNull} && ${eval1.value}) { - } else { - ${eval2.code} - if (!${eval2.isNull} && ${eval2.value}) { - } else if (!${eval1.isNull} && !${eval2.isNull}) { - ${ev.value} = false; + if (!${eval1.value}) { + ${eval2.code} + ${ev.value} = ${eval2.value}; + } + """ + } else { + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.value} = true; + + if (!${eval1.isNull} && ${eval1.value}) { } else { - ${ev.isNull} = true; + ${eval2.code} + if (!${eval2.isNull} && ${eval2.value}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.value} = false; + } else { + ${ev.isNull} = true; + } } - } - """ + """ + } } } -- cgit v1.2.3 From 1528ff4c9affe1df103c4b3abd56a86c71d8b753 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 28 Mar 2016 10:43:54 -0700 Subject: [SPARK-14156][SQL] Use executedPlan in HiveComparisonTest for the messages of computed tables ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-14156 In HiveComparisonTest, when catalyst results are different to hive results, we will collect the messages for computed tables during the test. During creating the message, we use sparkPlan. But we actually run the query with executedPlan. So the error message is sometimes confusing. For example, as wholestage codegen is enabled by default now. The shown spark plan for computed tables is the plan before wholestage codegen. A concrete is the following error message shown before this patch. It is the error shown when running `HiveCompatibilityTest` `auto_join26`. auto_join26 has one SQL to create table: INSERT OVERWRITE TABLE dest_j1 SELECT x.key, count(1) FROM src1 x JOIN src y ON (x.key = y.key) group by x.key; (1) Then a SQL to retrieve the result: select * from dest_j1 x order by x.key; (2) When the above SQL (2) to retrieve the result fails, In `HiveComparisonTest` we will try to collect and show the generated data from table `dest_j1` using the SQL (1)'s spark plan. 
The you will see this error: TungstenAggregate(key=[key#8804], functions=[(count(1),mode=Partial,isDistinct=false)], output=[key#8804,count#8834L]) +- Project [key#8804] +- BroadcastHashJoin [key#8804], [key#8806], Inner, BuildRight, None :- Filter isnotnull(key#8804) : +- InMemoryColumnarTableScan [key#8804], [isnotnull(key#8804)], InMemoryRelation [key#8804,value#8805], true, 5, StorageLevel(true, true, false, true, 1), HiveTableScan [key#8717,value#8718], MetastoreRelation default, src1, None, Some(src1) +- Filter isnotnull(key#8806) +- InMemoryColumnarTableScan [key#8806], [isnotnull(key#8806)], InMemoryRelation [key#8806,value#8807], true, 5, StorageLevel(true, true, false, true, 1), HiveTableScan [key#8760,value#8761], MetastoreRelation default, src, None, Some(src) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:47) at org.apache.spark.sql.execution.aggregate.TungstenAggregate.doExecute(TungstenAggregate.scala:82) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:121) at org.apache.spark.sql.execution.SparkPlan$$anonfun$execute$1.apply(SparkPlan.scala:121) at org.apache.spark.sql.execution.SparkPlan$$anonfun$executeQuery$1.apply(SparkPlan.scala:140) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:137) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:120) at org.apache.spark.sql.execution.aggregate.TungstenAggregate$$anonfun$doExecute$1.apply(TungstenAggregate.scala:87) at org.apache.spark.sql.execution.aggregate.TungstenAggregate$$anonfun$doExecute$1.apply(TungstenAggregate.scala:82) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:46) ... 70 more Caused by: java.lang.UnsupportedOperationException: Filter does not implement doExecuteBroadcast at org.apache.spark.sql.execution.SparkPlan.doExecuteBroadcast(SparkPlan.scala:221) The message is confusing because it is not the plan actually run by SparkSQL engine to create the generated table. The plan actually run is no problem. But as before this patch, we run `e.sparkPlan.collect` to retrieve and show the generated data, spark plan is not the plan we can run. So the above error will be shown. After this patch, we won't see the error because the executed plan is no problem and works. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #11957 from viirya/use-executedplan. --- .../org/apache/spark/sql/hive/execution/HiveComparisonTest.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index cfca93bbf0..4c1b425b16 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -480,7 +480,11 @@ abstract class HiveComparisonTest val executions = queryList.map(new TestHive.QueryExecution(_)) executions.foreach(_.toRdd) val tablesGenerated = queryList.zip(executions).flatMap { - case (q, e) => e.sparkPlan.collect { + // We should take executedPlan instead of sparkPlan, because in following codes we + // will run the collected plans. 
As we will do extra processing for sparkPlan such + // as adding exchage, collapsing codegen stages, etc., collecing sparkPlan here + // will cause some errors when running these plans later. + case (q, e) => e.executedPlan.collect { case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => (q, e, i) } -- cgit v1.2.3 From 600c0b69cab4767e8e5a6f4284777d8b9d4bd40e Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 28 Mar 2016 12:31:12 -0700 Subject: [SPARK-13713][SQL] Migrate parser from ANTLR3 to ANTLR4 ### What changes were proposed in this pull request? The current ANTLR3 parser is quite complex to maintain and suffers from code blow-ups. This PR introduces a new parser that is based on ANTLR4. This parser is based on the [Presto's SQL parser](https://github.com/facebook/presto/blob/master/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4). The current implementation can parse and create Catalyst and SQL plans. Large parts of the HiveQl DDL and some of the DML functionality is currently missing, the plan is to add this in follow-up PRs. This PR is a work in progress, and work needs to be done in the following area's: - [x] Error handling should be improved. - [x] Documentation should be improved. - [x] Multi-Insert needs to be tested. - [ ] Naming and package locations. ### How was this patch tested? Catalyst and SQL unit tests. Author: Herman van Hovell Closes #11557 from hvanhovell/ngParser. --- LICENSE | 1 + dev/deps/spark-deps-hadoop-2.2 | 1 + dev/deps/spark-deps-hadoop-2.3 | 1 + dev/deps/spark-deps-hadoop-2.4 | 1 + dev/deps/spark-deps-hadoop-2.6 | 1 + dev/deps/spark-deps-hadoop-2.7 | 1 + pom.xml | 6 + project/SparkBuild.scala | 8 +- project/plugins.sbt | 6 + python/pyspark/sql/tests.py | 6 +- python/pyspark/sql/utils.py | 8 + sql/catalyst/pom.xml | 4 + .../apache/spark/sql/catalyst/parser/ng/SqlBase.g4 | 911 ++++++++++++ .../apache/spark/sql/catalyst/dsl/package.scala | 34 +- .../spark/sql/catalyst/parser/ng/AstBuilder.scala | 1452 ++++++++++++++++++++ .../spark/sql/catalyst/parser/ng/ParseDriver.scala | 240 ++++ .../spark/sql/catalyst/parser/ng/ParserUtils.scala | 118 ++ .../sql/catalyst/parser/CatalystQlSuite.scala | 52 +- .../sql/catalyst/parser/DataTypeParserSuite.scala | 55 +- .../sql/catalyst/parser/ng/ErrorParserSuite.scala | 67 + .../catalyst/parser/ng/ExpressionParserSuite.scala | 497 +++++++ .../sql/catalyst/parser/ng/PlanParserSuite.scala | 429 ++++++ .../parser/ng/TableIdentifierParserSuite.scala | 42 + .../apache/spark/sql/catalyst/plans/PlanTest.scala | 20 +- .../spark/sql/execution/SparkSqlParser.scala | 219 +++ .../scala/org/apache/spark/sql/functions.scala | 5 +- .../apache/spark/sql/internal/SessionState.scala | 2 +- .../scala/org/apache/spark/sql/JoinSuite.scala | 4 +- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 29 files changed, 4127 insertions(+), 66 deletions(-) create mode 100644 sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala create mode 
100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala diff --git a/LICENSE b/LICENSE index d7a790a628..5a8c78b98b 100644 --- a/LICENSE +++ b/LICENSE @@ -238,6 +238,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) + (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 512675a599..7c2f88bdb1 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 31f8694fed..f4d600038d 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 0fa8bccab0..7c5e2c35bd 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 8d2f6e6e32..03d9a51057 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a114c4ae8d..5765071a1c 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar diff --git a/pom.xml b/pom.xml index b4cfa3a598..475f0544bd 100644 --- a/pom.xml +++ b/pom.xml @@ -178,6 +178,7 @@ 1.3.9 0.9.2 3.5.2 + 4.5.2-1 ${java.home} @@ -1759,6 +1760,11 @@ antlr-runtime ${antlr.version} + + org.antlr + antlr4-runtime + ${antlr4.version} + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index fb229b979d..39a9e16f7e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -25,6 +25,7 @@ import 
sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import com.typesafe.tools.mima.plugin.MimaKeys @@ -401,7 +402,10 @@ object OldDeps { } object Catalyst { - lazy val settings = Seq( + lazy val settings = antlr4Settings ++ Seq( + antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng"), + antlr4GenListener in Antlr4 := true, + antlr4GenVisitor in Antlr4 := true, // ANTLR code-generation step. // // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of @@ -414,7 +418,7 @@ object Catalyst { "SparkSqlLexer.g", "SparkSqlParser.g") val sourceDir = (sourceDirectory in Compile).value / "antlr3" - val targetDir = (sourceManaged in Compile).value + val targetDir = (sourceManaged in Compile).value / "antlr3" // Create default ANTLR Tool. val antlr = new org.antlr.Tool diff --git a/project/plugins.sbt b/project/plugins.sbt index eeca94a47c..d9ed7962bf 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -23,3 +23,9 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" libraryDependencies += "org.antlr" % "antlr" % "3.5.2" + + +// TODO I am not sure we want such a dep. +resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases" + +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 83ef76c13c..1a5d422af9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction, sha2 from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException, IllegalArgumentException +from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException class UTCOffsetTimezone(datetime.tzinfo): @@ -1130,7 +1130,9 @@ class SQLTests(ReusedPySparkTestCase): def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc")) + + def test_capture_parse_exception(self): + self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc")) def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index b0a0373372..b89ea8c6e0 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -33,6 +33,12 @@ class AnalysisException(CapturedException): """ +class ParseException(CapturedException): + """ + Failed to parse a SQL command. + """ + + class IllegalArgumentException(CapturedException): """ Passed an illegal or inappropriate argument. 
@@ -49,6 +55,8 @@ def capture_sql_exception(f): e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1], stackTrace) + if s.startswith('org.apache.spark.sql.catalyst.parser.ng.ParseException: '): + raise ParseException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) raise diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5d1d9edd25..c834a011f1 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -75,6 +75,10 @@ org.antlr antlr-runtime + + org.antlr + antlr4-runtime + commons-codec commons-codec diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 new file mode 100644 index 0000000000..e46fd9bed5 --- /dev/null +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 @@ -0,0 +1,911 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +grammar SqlBase; + +tokens { + DELIMITER +} + +singleStatement + : statement EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +statement + : query #statementDefault + | USE db=identifier #use + | CREATE DATABASE (IF NOT EXISTS)? identifier + (COMMENT comment=STRING)? locationSpec? + (WITH DBPROPERTIES tablePropertyList)? #createDatabase + | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties + | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase + | createTableHeader ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTableUsing + | createTableHeader tableProvider + (OPTIONS tablePropertyList)? AS? query #createTableUsing + | createTableHeader ('(' colTypeList ')')? (COMMENT STRING)? + (PARTITIONED BY identifierList)? bucketSpec? skewSpec? + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + (AS? query)? #createTable + | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq?) #analyze + | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable + | ALTER TABLE tableIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER TABLE tableIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE tableIdentifier (partitionSpec)? 
+ SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER TABLE tableIdentifier bucketSpec #bucketTable + | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable + | ALTER TABLE tableIdentifier NOT SORTED #unsortTable + | ALTER TABLE tableIdentifier skewSpec #skewTable + | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable + | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable + | ALTER TABLE tableIdentifier + SET SKEWED LOCATION skewedLocationList #setTableSkewLocations + | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE tableIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER TABLE from=tableIdentifier + EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition + | ALTER TABLE tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition + | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition + | ALTER TABLE tableIdentifier partitionSpec? + SET FILEFORMAT fileFormat #setTableFileFormat + | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation + | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable + | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable + | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable + | ALTER TABLE tableIdentifier partitionSpec? + CHANGE COLUMN? oldName=identifier colType + (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn + | ALTER TABLE tableIdentifier partitionSpec? + ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns + | ALTER TABLE tableIdentifier partitionSpec? + REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns + | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? + (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier + identifierCommentList? (COMMENT STRING)? + (PARTITIONED ON identifierList)? + (TBLPROPERTIES tablePropertyList)? AS query #createView + | ALTER VIEW tableIdentifier AS? query #alterViewQuery + | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction + | EXPLAIN explainOption* statement #explain + | SHOW TABLES ((FROM | IN) db=identifier)? + (LIKE (qualifiedName | pattern=STRING))? #showTables + | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions + | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction + | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? + tableIdentifier partitionSpec? describeColName? #describeTable + | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase + | REFRESH TABLE tableIdentifier #refreshTable + | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable + | UNCACHE TABLE identifier #uncacheTable + | CLEAR CACHE #clearCache + | ADD identifier .*? #addResource + | SET .*? #setConfiguration + | hiveNativeCommands #executeNativeCommand + ; + +hiveNativeCommands + : createTableHeader LIKE tableIdentifier + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + | DELETE FROM tableIdentifier (WHERE booleanExpression)? + | TRUNCATE TABLE tableIdentifier partitionSpec? + (COLUMNS identifierList)? + | ALTER VIEW from=tableIdentifier AS? 
RENAME TO to=tableIdentifier + | ALTER VIEW from=tableIdentifier AS? + SET TBLPROPERTIES tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + ADD (IF NOT EXISTS)? partitionSpecLocation+ + | ALTER VIEW from=tableIdentifier AS? + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? + | DROP VIEW (IF EXISTS)? qualifiedName + | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? + | START TRANSACTION (transactionMode (',' transactionMode)*)? + | COMMIT WORK? + | ROLLBACK WORK? + | SHOW PARTITIONS tableIdentifier partitionSpec? + | DFS .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD) .*? + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +query + : ctes? queryNoWith + ; + +insertInto + : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? + | INSERT INTO TABLE? tableIdentifier partitionSpec? + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +describeColName + : identifier ('.' (identifier | STRING))* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=identifier AS? '(' queryNoWith ')' + ; + +tableProvider + : USING qualifiedName + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=STRING)? + ; + +tablePropertyKey + : looseIdentifier ('.' looseIdentifier)* + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +skewedLocation + : (constant | constantList) EQ STRING + ; + +skewedLocationList + : '(' skewedLocation (',' skewedLocation)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? + (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +queryNoWith + : insertInto? queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windows? + (LIMIT limit=expression)? + ; + +multiInsertQueryBody + : insertInto? + querySpecification + queryOrganization + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | TABLE tableIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' queryNoWith ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? 
+ ; + +querySpecification + : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')' + | kind=MAP namedExpressionSeq + | kind=REDUCE namedExpressionSeq)) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + fromClause? + (WHERE where=booleanExpression)?) + | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) + lateralView* + (WHERE where=booleanExpression)? + aggregation? + (HAVING having=booleanExpression)? + windows?) + ; + +fromClause + : FROM relation (',' relation)* lateralView* + ; + +aggregation + : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : left=relation + ((CROSS | joinType) JOIN right=relation joinCriteria? + | NATURAL joinType JOIN right=relation + ) #joinRelation + | relationPrimary #relationDefault + ; + +joinType + : INNER? + | LEFT OUTER? + | LEFT SEMI + | RIGHT OUTER? + | FULL OUTER? + ; + +joinCriteria + : ON booleanExpression + | USING '(' identifier (',' identifier)* ')' + ; + +sample + : TABLESAMPLE '(' + ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) + | (expression sampleType=ROWS) + | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) + ')' + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : identifier (',' identifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : identifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier (COMMENT STRING)? + ; + +relationPrimary + : tableIdentifier sample? (AS? identifier)? #tableName + | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery + | '(' relation ')' sample? (AS? identifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + ; + +inlineTable + : VALUES expression (',' expression)* (AS? identifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +tableIdentifier + : (db=identifier '.')? table=identifier + ; + +namedExpression + : expression (AS? (identifier | identifierList))? 
+ ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +expression + : booleanExpression + ; + +booleanExpression + : predicated #booleanDefault + | NOT booleanExpression #logicalNot + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + | EXISTS '(' query ')' #exists + ; + +// workaround for: +// https://github.com/antlr/antlr4/issues/780 +// https://github.com/antlr/antlr4/issues/781 +predicated + : valueExpression predicate[$valueExpression.ctx]? + ; + +predicate[ParserRuleContext value] + : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between + | NOT? IN '(' expression (',' expression)* ')' #inList + | NOT? IN '(' query ')' #inSubquery + | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like + | IS NOT? NULL #nullPredicate + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : constant #constantDefault + | ASTERISK #star + | qualifiedName '.' ASTERISK #star + | '(' expression (',' expression)+ ')' #rowConstructor + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | '(' query ')' #subqueryExpression + | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL intervalField* + ; + +intervalField + : value=intervalValue unit=identifier (TO to=identifier)? + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE) + | STRING + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : identifier ':'? dataType (COMMENT STRING)? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windows + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : identifier AS windowSpec + ; + +windowSpec + : name=identifier #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? 
+ ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + + +explainOption + : LOGICAL | FORMATTED | EXTENDED + ; + +transactionMode + : ISOLATION LEVEL SNAPSHOT #isolationLevel + | READ accessMode=(ONLY | WRITE) #transactionAccessMode + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). +looseIdentifier + : identifier + | FROM + | TO + | TABLE + | WITH + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : DECIMAL_VALUE #decimalLiteral + | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral + | INTEGER_VALUE #integerLiteral + | BIGINT_LITERAL #bigIntLiteral + | SMALLINT_LITERAL #smallIntLiteral + | TINYINT_LITERAL #tinyIntLiteral + | DOUBLE_LITERAL #doubleLiteral + ; + +nonReserved + : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS + | ADD + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT + | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER + | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED + | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS + | GROUPING | CUBE | ROLLUP + | EXPLAIN | FORMAT | LOGICAL | FORMATTED + | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF + | SET + | VIEW | REPLACE + | IF + | NO | DATA + | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL + | SNAPSHOT | READ | WRITE | ONLY + | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION + | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST + | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT + | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE + | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT + ; + +SELECT: 'SELECT'; +FROM: 'FROM'; +ADD: 'ADD'; +AS: 'AS'; +ALL: 'ALL'; +DISTINCT: 'DISTINCT'; +WHERE: 'WHERE'; +GROUP: 'GROUP'; +BY: 'BY'; +GROUPING: 'GROUPING'; +SETS: 'SETS'; +CUBE: 'CUBE'; +ROLLUP: 'ROLLUP'; +ORDER: 'ORDER'; +HAVING: 'HAVING'; +LIMIT: 'LIMIT'; +AT: 'AT'; +OR: 'OR'; +AND: 'AND'; +IN: 'IN'; +NOT: 'NOT' | '!'; +NO: 'NO'; +EXISTS: 'EXISTS'; +BETWEEN: 'BETWEEN'; +LIKE: 'LIKE'; +RLIKE: 'RLIKE' | 'REGEXP'; +IS: 'IS'; +NULL: 'NULL'; +TRUE: 'TRUE'; +FALSE: 'FALSE'; +NULLS: 'NULLS'; +ASC: 'ASC'; +DESC: 'DESC'; +FOR: 'FOR'; +INTERVAL: 'INTERVAL'; +CASE: 'CASE'; +WHEN: 'WHEN'; +THEN: 'THEN'; +ELSE: 'ELSE'; +END: 'END'; +JOIN: 'JOIN'; +CROSS: 'CROSS'; +OUTER: 'OUTER'; +INNER: 'INNER'; +LEFT: 'LEFT'; +SEMI: 'SEMI'; +RIGHT: 'RIGHT'; +FULL: 'FULL'; +NATURAL: 'NATURAL'; +ON: 'ON'; +LATERAL: 'LATERAL'; +WINDOW: 'WINDOW'; +OVER: 'OVER'; +PARTITION: 'PARTITION'; +RANGE: 'RANGE'; +ROWS: 'ROWS'; +UNBOUNDED: 'UNBOUNDED'; 
+PRECEDING: 'PRECEDING'; +FOLLOWING: 'FOLLOWING'; +CURRENT: 'CURRENT'; +ROW: 'ROW'; +WITH: 'WITH'; +VALUES: 'VALUES'; +CREATE: 'CREATE'; +TABLE: 'TABLE'; +VIEW: 'VIEW'; +REPLACE: 'REPLACE'; +INSERT: 'INSERT'; +DELETE: 'DELETE'; +INTO: 'INTO'; +DESCRIBE: 'DESCRIBE'; +EXPLAIN: 'EXPLAIN'; +FORMAT: 'FORMAT'; +LOGICAL: 'LOGICAL'; +CAST: 'CAST'; +SHOW: 'SHOW'; +TABLES: 'TABLES'; +COLUMNS: 'COLUMNS'; +COLUMN: 'COLUMN'; +USE: 'USE'; +PARTITIONS: 'PARTITIONS'; +FUNCTIONS: 'FUNCTIONS'; +DROP: 'DROP'; +UNION: 'UNION'; +EXCEPT: 'EXCEPT'; +INTERSECT: 'INTERSECT'; +TO: 'TO'; +TABLESAMPLE: 'TABLESAMPLE'; +STRATIFY: 'STRATIFY'; +ALTER: 'ALTER'; +RENAME: 'RENAME'; +ARRAY: 'ARRAY'; +MAP: 'MAP'; +STRUCT: 'STRUCT'; +COMMENT: 'COMMENT'; +SET: 'SET'; +DATA: 'DATA'; +START: 'START'; +TRANSACTION: 'TRANSACTION'; +COMMIT: 'COMMIT'; +ROLLBACK: 'ROLLBACK'; +WORK: 'WORK'; +ISOLATION: 'ISOLATION'; +LEVEL: 'LEVEL'; +SNAPSHOT: 'SNAPSHOT'; +READ: 'READ'; +WRITE: 'WRITE'; +ONLY: 'ONLY'; + +IF: 'IF'; + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<='; +GT : '>'; +GTE : '>='; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +DIV: 'DIV'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +HAT: '^'; + +PERCENTLIT: 'PERCENT'; +BUCKET: 'BUCKET'; +OUT: 'OUT'; +OF: 'OF'; + +SORT: 'SORT'; +CLUSTER: 'CLUSTER'; +DISTRIBUTE: 'DISTRIBUTE'; +OVERWRITE: 'OVERWRITE'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; +USING: 'USING'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +DELIMITED: 'DELIMITED'; +FIELDS: 'FIELDS'; +TERMINATED: 'TERMINATED'; +COLLECTION: 'COLLECTION'; +ITEMS: 'ITEMS'; +KEYS: 'KEYS'; +ESCAPED: 'ESCAPED'; +LINES: 'LINES'; +SEPARATED: 'SEPARATED'; +FUNCTION: 'FUNCTION'; +EXTENDED: 'EXTENDED'; +REFRESH: 'REFRESH'; +CLEAR: 'CLEAR'; +CACHE: 'CACHE'; +UNCACHE: 'UNCACHE'; +LAZY: 'LAZY'; +FORMATTED: 'FORMATTED'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +OPTIONS: 'OPTIONS'; +UNSET: 'UNSET'; +TBLPROPERTIES: 'TBLPROPERTIES'; +DBPROPERTIES: 'DBPROPERTIES'; +BUCKETS: 'BUCKETS'; +SKEWED: 'SKEWED'; +STORED: 'STORED'; +DIRECTORIES: 'DIRECTORIES'; +LOCATION: 'LOCATION'; +EXCHANGE: 'EXCHANGE'; +ARCHIVE: 'ARCHIVE'; +UNARCHIVE: 'UNARCHIVE'; +FILEFORMAT: 'FILEFORMAT'; +TOUCH: 'TOUCH'; +COMPACT: 'COMPACT'; +CONCATENATE: 'CONCATENATE'; +CHANGE: 'CHANGE'; +FIRST: 'FIRST'; +AFTER: 'AFTER'; +CASCADE: 'CASCADE'; +RESTRICT: 'RESTRICT'; +CLUSTERED: 'CLUSTERED'; +SORTED: 'SORTED'; +PURGE: 'PURGE'; +INPUTFORMAT: 'INPUTFORMAT'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +INPUTDRIVER: 'INPUTDRIVER'; +OUTPUTDRIVER: 'OUTPUTDRIVER'; +DATABASE: 'DATABASE' | 'SCHEMA'; +DFS: 'DFS'; +TRUNCATE: 'TRUNCATE'; +METADATA: 'METADATA'; +REPLICATION: 'REPLICATION'; +ANALYZE: 'ANALYZE'; +COMPUTE: 'COMPUTE'; +STATISTICS: 'STATISTICS'; +PARTITIONED: 'PARTITIONED'; +EXTERNAL: 'EXTERNAL'; +DEFINED: 'DEFINED'; +REVOKE: 'REVOKE'; +GRANT: 'GRANT'; +LOCK: 'LOCK'; +UNLOCK: 'UNLOCK'; +MSCK: 'MSCK'; +EXPORT: 'EXPORT'; +IMPORT: 'IMPORT'; +LOAD: 'LOAD'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +SCIENTIFIC_DECIMAL_VALUE + : DIGIT+ ('.' DIGIT*)? EXPONENT + | '.' 
DIGIT+ EXPONENT + ; + +DOUBLE_LITERAL + : + (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' .*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3540014c3e..105947028d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -161,6 +161,10 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def star(names: String*): Expression = names match { + case Seq() => UnresolvedStar(None) + case target => UnresolvedStar(Option(target)) + } implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? @@ -231,6 +235,12 @@ package object dsl { AttributeReference(s, structType, nullable = true)() def struct(attrs: AttributeReference*): AttributeReference = struct(StructType.fromAttributes(attrs)) + + /** Create a function. 
*/ + def function(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = false) + def distinctFunction(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = true) } implicit class DslAttribute(a: AttributeReference) { @@ -243,8 +253,20 @@ package object dsl { object expressions extends ExpressionConversions // scalastyle:ignore object plans { // scalastyle:ignore + def table(ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref), None) + + def table(db: String, ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref, Option(db)), None) + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + def select(exprs: Expression*): LogicalPlan = { + val namedExpressions = exprs.map { + case e: NamedExpression => e + case e => UnresolvedAlias(e) + } + Project(namedExpressions, logicalPlan) + } def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) @@ -296,6 +318,14 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) + def as(alias: String): LogicalPlan = logicalPlan match { + case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) + case plan => SubqueryAlias(alias, plan) + } + + def distribute(exprs: Expression*): LogicalPlan = + RepartitionByExpression(exprs, logicalPlan) + def analyze: LogicalPlan = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala new file mode 100644 index 0000000000..5a64c414fb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala @@ -0,0 +1,1452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser.ng + +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or + * TableIdentifier. + */ +class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { + import ParserUtils._ + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + visit(ctx.dataType).asInstanceOf[DataType] + } + + /* ******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Make sure we do not try to create a plan for a native command. + */ + override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null + + /** + * Create a plan for a SHOW FUNCTIONS command. + */ + override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { + import ctx._ + if (qualifiedName != null) { + val names = qualifiedName().identifier().asScala.map(_.getText).toList + names match { + case db :: name :: Nil => + ShowFunctions(Some(db), Some(name)) + case name :: Nil => + ShowFunctions(None, Some(name)) + case _ => + throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) + } + } else if (pattern != null) { + ShowFunctions(None, Some(string(pattern))) + } else { + ShowFunctions(None, None) + } + } + + /** + * Create a plan for a DESCRIBE FUNCTION command. + */ + override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { + val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".") + DescribeFunction(functionName, ctx.EXTENDED != null) + } + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryNoWith) + + // Apply CTEs + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { + case nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + + // Check for duplicate names. 
+ ctes.groupBy(_._1).filter(_._2.size > 1).foreach { + case (name, _) => + throw new ParseException( + s"Name '$name' is used for multiple common table expressions", ctx) + } + + With(query, ctes.toMap) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith)) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { + body => + assert(body.querySpecification.fromClause == null, + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", + body) + + withQuerySpecification(body.querySpecification, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) + } + + // If there are multiple INSERTS just UNION them together into one query. + inserts match { + case Seq(query) => query + case queries => Union(queries) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) + } + + /** + * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan. + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + InsertIntoTable( + UnresolvedRelation(tableIdent, None), + partitionKeys, + query, + ctx.OVERWRITE != null, + ctx.EXISTS != null) + } + + /** + * Create a partition specification map. + */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText.toLowerCase + val value = Option(pVal.constant).map(visitStringConstant) + name -> value + }.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. 
+ */ + protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { + ctx match { + case s: StringLiteralContext => createString(s) + case o => o.getText + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem), global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem), global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + RepartitionByExpression(expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem), + global = false, + RepartitionByExpression(expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + RepartitionByExpression(expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windows)(withWindows) + + // LIMIT + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a logical plan using a query specification. + */ + override def visitQuerySpecification( + ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation.optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withQuerySpecification(ctx, from) + } + + /** + * Add a query specification to a logical plan. The query specification is the core of the logical + * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), + * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withQuerySpecification( + ctx: QuerySpecificationContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // WHERE + def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx), plan) + } + + // Expressions. + val expressions = Option(namedExpressionSeq).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + + // Create either a transform or a regular query. + val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT) + specType match { + case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM => + // Transform + + // Add where. + val withFilter = relation.optionalMap(where)(filter) + + // Create the attributes. 
+ val (attributes, schemaLess) = if (colTypeList != null) { + // Typed return columns. + (createStructType(colTypeList).toAttributes, false) + } else if (identifierSeq != null) { + // Untyped return columns. + val attrs = visitIdentifierSeq(identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + // Create the transform. + ScriptTransformation( + expressions, + string(script), + attributes, + withFilter, + withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + + case SqlBaseParser.SELECT => + // Regular select + + // Add lateral views. + val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(where)(filter) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + val withProject = if (aggregation != null) { + withAggregation(aggregation, namedExpressions, withFilter) + } else if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + // Having + val withHaving = withProject.optional(having) { + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(expression(having), BooleanType), withProject) + } + + // Distinct + val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { + Distinct(withHaving) + } else { + withHaving + } + + // Window + withDistinct.optionalMap(windows)(withWindows) + } + } + + /** + * Create a (Hive based) [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = null + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. + */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + + /** + * Connect two queries by a Set operator. + * + * Supported Set operators are: + * - UNION [DISTINCT] + * - UNION ALL + * - EXCEPT [DISTINCT] + * - INTERSECT [DISTINCT] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case SqlBaseParser.UNION if all => + Union(left, right) + case SqlBaseParser.UNION => + Distinct(Union(left, right)) + case SqlBaseParser.INTERSECT if all => + throw new ParseException("INTERSECT ALL is not supported.", ctx) + case SqlBaseParser.INTERSECT => + Intersect(left, right) + case SqlBaseParser.EXCEPT if all => + throw new ParseException("EXCEPT ALL is not supported.", ctx) + case SqlBaseParser.EXCEPT => + Except(left, right) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. 
+ */ + private def withWindows( + ctx: WindowsContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowMap = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + }.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity), query) + } + + /** + * Add an [[Aggregate]] to a logical plan. + */ + private def withAggregation( + ctx: AggregationContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + val groupByExpressions = expressionList(groupingExpressions) + + if (GROUPING != null) { + // GROUP BY .... GROUPING SETS (...) + val expressionMap = groupByExpressions.zipWithIndex.toMap + val numExpressions = expressionMap.size + val mask = (1 << numExpressions) - 1 + val masks = ctx.groupingSet.asScala.map { + _.expression.asScala.foldLeft(mask) { + case (bitmap, eCtx) => + // Find the index of the expression. + val e = typedVisit[Expression](eCtx) + val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( + throw new ParseException( + s"$e doesn't show up in the GROUP BY list", ctx)) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + bitmap & ~(1 << (numExpressions - 1 - index)) + } + } + GroupingSets(masks, groupByExpressions, query, selectExpressions) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (CUBE != null) { + Seq(Cube(groupByExpressions)) + } else if (ROLLUP != null) { + Seq(Rollup(groupByExpressions)) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. + */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + + // Create the generator. + val generator = ctx.qualifiedName.getText.toLowerCase match { + case "explode" if expressions.size == 1 => + Explode(expressions.head) + case "json_tuple" => + JsonTuple(expressions) + case other => + withGenerator(other, expressions, ctx) + } + + Generate( + generator, + join = true, + outer = ctx.OUTER != null, + Some(ctx.tblName.getText.toLowerCase), + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), + query) + } + + /** + * Create a [[Generator]]. Override this method in order to support custom Generators. + */ + protected def withGenerator( + name: String, + expressions: Seq[Expression], + ctx: LateralViewContext): Generator = { + throw new ParseException(s"Generator function '$name' is not supported", ctx) + } + + /** + * Create a joins between two or more logical plans. 
+ */ + override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { + /** Build a join between two plans. */ + def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { + val baseJoinType = ctx.joinType match { + case null => Inner + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if ctx.NATURAL != null => + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, right, joinType, condition) + } + + // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the + // first join clause is at the top. However fields of previously referenced tables can be used + // in following join clauses. The tree needs to be reversed in order to make this work. + var result = plan(ctx.left) + var current = ctx + while (current != null) { + current.right match { + case right: JoinRelationContext => + result = join(current, result, plan(right.left)) + current = right + case right => + result = join(current, result, plan(right)) + current = null + } + } + result + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + val eps = RandomSampler.roundingEpsilon + assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + } + + ctx.sampleType.getType match { + case SqlBaseParser.ROWS => + Limit(expression(ctx.expression), query) + + case SqlBaseParser.PERCENTLIT => + val fraction = ctx.percentage.getText.toDouble + sample(fraction / 100.0d) + + case SqlBaseParser.BUCKET if ctx.ON != null => + throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) + + case SqlBaseParser.BUCKET => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith) + } + + /** + * Create an un-aliased table reference. 
This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val table = UnresolvedRelation( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.identifier).map(_.getText)) + table.optionalMap(ctx.sample)(withSample) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val expressions = ctx.expression.asScala.map { eCtx => + val e = expression(eCtx) + assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) + e + } + + // Validate and evaluate the rows. + val (structType, structConstructor) = expressions.head.dataType match { + case st: StructType => + (st, (e: Expression) => e) + case dt => + val st = CreateStruct(Seq(expressions.head)).dataType + (st, (e: Expression) => CreateStruct(Seq(e))) + } + val rows = expressions.map { + case expression => + val safe = Cast(structConstructor(expression), structType) + safe.eval().asInstanceOf[InternalRow] + } + + // Construct attributes. + val baseAttributes = structType.toAttributes.map(_.withNullability(true)) + val attributes = if (ctx.identifierList != null) { + val aliases = visitIdentifierList(ctx.identifierList) + assert(aliases.size == baseAttributes.size, + "Number of aliases must match the number of fields in an inline table.", ctx) + baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + } else { + baseAttributes + } + + // Create plan and add an alias if a name has been defined. + LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a LogicalPlan. + */ + private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. 
+ */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText) + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create an expression from the given context. This method just passes the context on to the + * vistor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression) + } + + /** + * Invert a boolean expression if it has a valid NOT clause. + */ + private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { + if (not != null) { + Not(expression) + } else { + expression + } + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case SqlBaseParser.AND => And.apply _ + case SqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. + val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. 
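+ // For example, a AND b AND c AND d ends up as ((a AND b) AND (c AND d)) after the balanced
+ // reduction below, instead of a left-deep (((a AND b) AND c) AND d) tree.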
+ val expressions = contexts.reverse.map(expression)
+
+ // Create a balanced tree.
+ def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match {
+ case 0 =>
+ expressions(low)
+ case 1 =>
+ expressionCombiner(expressions(low), expressions(high))
+ case x =>
+ val mid = low + x / 2
+ expressionCombiner(
+ reduceToExpressionTree(low, mid),
+ reduceToExpressionTree(mid + 1, high))
+ }
+ reduceToExpressionTree(0, expressions.size - 1)
+ }
+
+ /**
+ * Invert a boolean expression.
+ */
+ override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) {
+ Not(expression(ctx.booleanExpression()))
+ }
+
+ /**
+ * Create a filtering correlated sub-query. This is not supported yet.
+ */
+ override def visitExists(ctx: ExistsContext): Expression = {
+ throw new ParseException("EXISTS clauses are not supported.", ctx)
+ }
+
+ /**
+ * Create a comparison expression. This compares two expressions. The following comparison
+ * operators are supported:
+ * - Equal: '=' or '=='
+ * - Null-safe Equal: '<=>'
+ * - Not Equal: '<>' or '!='
+ * - Less than: '<'
+ * - Less than or Equal: '<='
+ * - Greater than: '>'
+ * - Greater than or Equal: '>='
+ */
+ override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode]
+ operator.getSymbol.getType match {
+ case SqlBaseParser.EQ =>
+ EqualTo(left, right)
+ case SqlBaseParser.NSEQ =>
+ EqualNullSafe(left, right)
+ case SqlBaseParser.NEQ | SqlBaseParser.NEQJ =>
+ Not(EqualTo(left, right))
+ case SqlBaseParser.LT =>
+ LessThan(left, right)
+ case SqlBaseParser.LTE =>
+ LessThanOrEqual(left, right)
+ case SqlBaseParser.GT =>
+ GreaterThan(left, right)
+ case SqlBaseParser.GTE =>
+ GreaterThanOrEqual(left, right)
+ }
+ }
+
+ /**
+ * Create a BETWEEN expression. This tests if an expression lies within the bounds set by two
+ * other expressions. The inverse can also be created.
+ */
+ override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.value)
+ val between = And(
+ GreaterThanOrEqual(value, expression(ctx.lower)),
+ LessThanOrEqual(value, expression(ctx.upper)))
+ invertIfNotDefined(between, ctx.NOT)
+ }
+
+ /**
+ * Create an IN expression. This tests if the value of the left hand side expression is
+ * contained by the sequence of expressions on the right hand side.
+ */
+ override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) {
+ val in = In(expression(ctx.value), ctx.expression().asScala.map(expression))
+ invertIfNotDefined(in, ctx.NOT)
+ }
+
+ /**
+ * Create an IN expression, where the right hand side is a query. This is unsupported.
+ */
+ override def visitInSubquery(ctx: InSubqueryContext): Expression = {
+ throw new ParseException("IN with a Sub-query is currently not supported.", ctx)
+ }
+
+ /**
+ * Create a (R)LIKE/REGEXP expression.
+ */
+ override def visitLike(ctx: LikeContext): Expression = {
+ val left = expression(ctx.value)
+ val right = expression(ctx.pattern)
+ val like = ctx.like.getType match {
+ case SqlBaseParser.LIKE =>
+ Like(left, right)
+ case SqlBaseParser.RLIKE =>
+ RLike(left, right)
+ }
+ invertIfNotDefined(like, ctx.NOT)
+ }
+
+ /**
+ * Create an IS (NOT) NULL expression.
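+ * For example, `a IS NULL` is converted into `IsNull(a)` and `a IS NOT NULL` into `IsNotNull(a)`.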
+ */
+ override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.value)
+ if (ctx.NOT != null) {
+ IsNotNull(value)
+ } else {
+ IsNull(value)
+ }
+ }
+
+ /**
+ * Create a binary arithmetic expression. The following arithmetic operators are supported:
+ * - Multiplication: '*'
+ * - Division: '/'
+ * - Hive Long Division: 'DIV'
+ * - Modulo: '%'
+ * - Addition: '+'
+ * - Subtraction: '-'
+ * - Binary AND: '&'
+ * - Binary XOR: '^'
+ * - Binary OR: '|'
+ */
+ override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) {
+ val left = expression(ctx.left)
+ val right = expression(ctx.right)
+ ctx.operator.getType match {
+ case SqlBaseParser.ASTERISK =>
+ Multiply(left, right)
+ case SqlBaseParser.SLASH =>
+ Divide(left, right)
+ case SqlBaseParser.PERCENT =>
+ Remainder(left, right)
+ case SqlBaseParser.DIV =>
+ Cast(Divide(left, right), LongType)
+ case SqlBaseParser.PLUS =>
+ Add(left, right)
+ case SqlBaseParser.MINUS =>
+ Subtract(left, right)
+ case SqlBaseParser.AMPERSAND =>
+ BitwiseAnd(left, right)
+ case SqlBaseParser.HAT =>
+ BitwiseXor(left, right)
+ case SqlBaseParser.PIPE =>
+ BitwiseOr(left, right)
+ }
+ }
+
+ /**
+ * Create a unary arithmetic expression. The following arithmetic operators are supported:
+ * - Plus: '+'
+ * - Minus: '-'
+ * - Bitwise Not: '~'
+ */
+ override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) {
+ val value = expression(ctx.valueExpression)
+ ctx.operator.getType match {
+ case SqlBaseParser.PLUS =>
+ value
+ case SqlBaseParser.MINUS =>
+ UnaryMinus(value)
+ case SqlBaseParser.TILDE =>
+ BitwiseNot(value)
+ }
+ }
+
+ /**
+ * Create a [[Cast]] expression.
+ */
+ override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) {
+ Cast(expression(ctx.expression), typedVisit(ctx.dataType))
+ }
+
+ /**
+ * Create a (windowed) Function expression.
+ */
+ override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) {
+ // Create the function call.
+ val name = ctx.qualifiedName.getText
+ val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null)
+ val arguments = ctx.expression().asScala.map(expression) match {
+ case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct =>
+ // Transform COUNT(*) into COUNT(1). Move this to analysis?
+ Seq(Literal(1))
+ case expressions =>
+ expressions
+ }
+ val function = UnresolvedFunction(name, arguments, isDistinct)
+
+ // Check if the function is evaluated in a windowed context.
+ ctx.windowSpec match {
+ case spec: WindowRefContext =>
+ UnresolvedWindowExpression(function, visitWindowRef(spec))
+ case spec: WindowDefContext =>
+ WindowExpression(function, visitWindowDef(spec))
+ case _ => function
+ }
+ }
+
+ /**
+ * Create a reference to a window frame, i.e. [[WindowSpecReference]].
+ */
+ override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
+ WindowSpecReference(ctx.identifier.getText)
+ }
+
+ /**
+ * Create a window definition, i.e. [[WindowSpecDefinition]].
+ */
+ override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) {
+ // CLUSTER BY ... | PARTITION BY ... ORDER BY ...
+ val partition = ctx.partition.asScala.map(expression)
+ val order = ctx.sortItem.asScala.map(visitSortItem)
+
+ // RANGE/ROWS BETWEEN ...
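+ // e.g. ROWS BETWEEN 10 PRECEDING AND CURRENT ROW becomes
+ // SpecifiedWindowFrame(RowFrame, ValuePreceding(10), CurrentRow).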
+ val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case SqlBaseParser.RANGE => RangeFrame + case SqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition, + order, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value + * Preceding/Following boundaries. These expressions must be constant (foldable) and return an + * integer value. + */ + override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { + // We currently only allow foldable integers. + def value: Int = { + val e = expression(ctx.expression) + assert(e.resolved && e.foldable && e.dataType == IntegerType, + "Frame bound value must be a constant integer.", + ctx) + e.eval().asInstanceOf[Int] + } + + // Create the FrameBoundary + ctx.boundType.getType match { + case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case SqlBaseParser.PRECEDING => + ValuePreceding(value) + case SqlBaseParser.CURRENT => + CurrentRow + case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case SqlBaseParser.FOLLOWING => + ValueFollowing(value) + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.expression.asScala.map(expression)) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a dereference expression. The return type depends on the type of the parent, this can + * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an + * [[UnresolvedExtractValue]] if the parent is some expression. + */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ attr) + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression. 
+ */
+ override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ UnresolvedAttribute.quoted(ctx.getText)
+ }
+
+ /**
+ * Create an [[UnresolvedExtractValue]] expression. This is used for subscript access to an array.
+ */
+ override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) {
+ UnresolvedExtractValue(expression(ctx.value), expression(ctx.index))
+ }
+
+ /**
+ * Create an expression for an expression between parentheses. This is needed because the ANTLR
+ * visitor cannot automatically convert the nested context into an expression.
+ */
+ override def visitParenthesizedExpression(
+ ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) {
+ expression(ctx.expression)
+ }
+
+ /**
+ * Create a [[SortOrder]] expression.
+ */
+ override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) {
+ if (ctx.DESC != null) {
+ SortOrder(expression(ctx.expression), Descending)
+ } else {
+ SortOrder(expression(ctx.expression), Ascending)
+ }
+ }
+
+ /**
+ * Create a typed Literal expression. A typed literal has the following SQL syntax:
+ * {{{
+ * [TYPE] '[VALUE]'
+ * }}}
+ * Currently Date and Timestamp typed literals are supported.
+ *
+ * TODO what is the added value of this over casting?
+ */
+ override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) {
+ val value = string(ctx.STRING)
+ ctx.identifier.getText.toUpperCase match {
+ case "DATE" =>
+ Literal(Date.valueOf(value))
+ case "TIMESTAMP" =>
+ Literal(Timestamp.valueOf(value))
+ case other =>
+ throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx)
+ }
+ }
+
+ /**
+ * Create a NULL literal expression.
+ */
+ override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) {
+ Literal(null)
+ }
+
+ /**
+ * Create a Boolean literal expression.
+ */
+ override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) {
+ if (ctx.getText.toBoolean) {
+ Literal.TrueLiteral
+ } else {
+ Literal.FalseLiteral
+ }
+ }
+
+ /**
+ * Create an integral literal expression. The code selects the narrowest integral type
+ * possible; either a BigDecimal, a Long or an Integer is returned.
+ */
+ override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) {
+ BigDecimal(ctx.getText) match {
+ case v if v.isValidInt =>
+ Literal(v.intValue())
+ case v if v.isValidLong =>
+ Literal(v.longValue())
+ case v => Literal(v.underlying())
+ }
+ }
+
+ /**
+ * Create a double literal for a number denoted in scientific notation.
+ */
+ override def visitScientificDecimalLiteral(
+ ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(ctx.getText.toDouble)
+ }
+
+ /**
+ * Create a decimal literal for a regular decimal number.
+ */
+ override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) {
+ Literal(BigDecimal(ctx.getText).underlying())
+ }
+
+ /** Create a numeric literal expression. */
+ private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) {
+ val raw = ctx.getText
+ try {
+ Literal(f(raw.substring(0, raw.length - 1)))
+ } catch {
+ case e: NumberFormatException =>
+ throw new ParseException(e.getMessage, ctx)
+ }
+ }
+
+ /**
+ * Create a Byte Literal expression.
+ */
+ override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) {
+ _.toByte
+ }
+
+ /**
+ * Create a Short Literal expression.
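+ * For example, `10S` becomes `Literal(10.toShort)`.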
+ */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { + _.toShort + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { + _.toLong + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { + _.toDouble + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + ctx.STRING().asScala.map(string).mkString + } + + /** + * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple + * unit value pairs, for instance: interval 2 months 2 days. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val intervals = ctx.intervalField.asScala.map(visitIntervalField) + assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + Literal(intervals.reduce(_.add(_))) + } + + /** + * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are + * supported: + * - Single unit. + * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). + */ + override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { + import ctx._ + val s = value.getText + val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + case (u, None) if u.endsWith("s") => + // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... + CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) + case (u, None) => + CalendarInterval.fromSingleUnitString(u, s) + case ("year", Some("month")) => + CalendarInterval.fromYearMonthString(s) + case ("day", Some("second")) => + CalendarInterval.fromDayTimeString(s) + case (from, Some(t)) => + throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) + } + assert(interval != null, "No interval can be constructed", ctx) + interval + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + /** + * Resolve/create a primitive type. 
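+ * For example, `bigint` maps to `LongType`, `decimal(10, 2)` to `DecimalType(10, 2)`, and
+ * `varchar(50)` to `StringType`.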
+ */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => TimestampType + case ("char" | "varchar" | "string", Nil) => StringType + case ("char" | "varchar", _ :: Nil) => StringType + case ("binary", Nil) => BinaryType + case ("decimal", Nil) => DecimalType.USER_DEFAULT + case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) + case ("decimal", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case (dt, params) => + throw new ParseException( + s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case SqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case SqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case SqlBaseParser.STRUCT => + createStructType(ctx.colTypeList()) + } + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType) + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + // Add the comment to the metadata. + val builder = new MetadataBuilder + if (STRING != null) { + builder.putString("comment", string(STRING)) + } + + StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala new file mode 100644 index 0000000000..c9a286374c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType + +/** + * Base SQL parsing infrastructure. + */ +abstract class AbstractSqlParser extends ParserInterface with Logging { + + /** Creates/Resolves DataType for a given SQL string. */ + def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => + // TODO add this to the parser interface. + astBuilder.visitSingleDataType(parser.singleDataType()) + } + + /** Creates Expression for a given SQL string. */ + override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => + astBuilder.visitSingleExpression(parser.singleExpression()) + } + + /** Creates TableIdentifier for a given SQL string. */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) + } + + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visitSingleStatement(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => nativeCommand(sqlText) + } + } + + /** Get the builder (visitor) which converts a ParseTree into a AST. */ + protected def astBuilder: AstBuilder + + /** Create a native command, or fail when this is not supported. */ + protected def nativeCommand(sqlText: String): LogicalPlan = { + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) + } + + protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + logInfo(s"Parsing command: $command") + + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.reset() // rewind input stream + parser.reset() + + // Try Again. + parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) + } + } +} + +/** + * Concrete SQL parser for Catalyst-only SQL statements. 
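+ * For example, `CatalystSqlParser.parseExpression("a + 1")` returns an unresolved `Add`
+ * expression over the attribute `a` and the literal `1`.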
+ */ +object CatalystSqlParser extends AbstractSqlParser { + val astBuilder = new AstBuilder +} + +/** + * This string stream provides the lexer with upper case characters only. This greatly simplifies + * lexing the stream, while we can maintain the original command. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream + * + * The comment below (taken from the original class) describes the rationale for doing this: + * + * This class provides and implementation for a case insensitive token checker for the lexical + * analysis part of antlr. By converting the token stream into upper case at the time when lexical + * rules are checked, this class ensures that the lexical rules need to just match the token with + * upper case letters as opposed to combination of upper case and lower case characters. This is + * purely used for matching lexical rules. The actual token text is stored in the same way as the + * user input without actually converting it into an upper case. The token values are generated by + * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead + * function and is purely used for matching lexical rules. This also means that the grammar will + * only accept capitalized tokens in case it is run from other tools like antlrworks which do not + * have the ANTLRNoCaseStringStream implementation. + */ + +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { + override def LA(i: Int): Int = { + val la = super.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +/** + * The ParseErrorListener converts parse errors into AnalysisExceptions. + */ +case object ParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) + } +} + +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. + */ +class ParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(ParserUtils.command(ctx)), + message, + ParserUtils.position(ctx.getStart), + ParserUtils.position(ctx.getStop)) + } + + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString + } + + def withCommand(cmd: String): ParseException = { + new ParseException(Option(cmd), message, start, stop) + } +} + +/** + * The post-processor validates & cleans-up the parse tree during the parse process. 
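+ * For example, it converts a quoted identifier into a plain identifier token by stripping the
+ * surrounding backticks and un-escaping any doubled backticks, and it re-labels non-reserved
+ * keywords as identifiers.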
+ */ +case object PostProcessor extends SqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + parent.addChild(f(new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + SqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins))) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala new file mode 100644 index 0000000000..1fbfa763b4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.TerminalNode + +import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} + +/** + * A collection of utility methods for use during the parsing process. + */ +object ParserUtils { + /** Get the command which created the token. */ + def command(ctx: ParserRuleContext): String = { + command(ctx.getStart.getInputStream) + } + + /** Get the command which created the token. */ + def command(stream: CharStream): String = { + stream.getText(Interval.of(0, stream.size())) + } + + /** Get the code that creates the given node. */ + def source(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) + } + + /** Get all the text which comes after the given rule. */ + def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) + + /** Get all the text which comes after the given token. 
*/ + def remainder(token: Token): String = { + val stream = token.getInputStream + val interval = Interval.of(token.getStopIndex + 1, stream.size()) + stream.getText(interval) + } + + /** Convert a string token into a string. */ + def string(token: Token): String = unescapeSQLString(token.getText) + + /** Convert a string node into a string. */ + def string(node: TerminalNode): String = unescapeSQLString(node.getText) + + /** Get the origin (line and position) of the token. */ + def position(token: Token): Origin = { + Origin(Option(token.getLine), Option(token.getCharPositionInLine)) + } + + /** Assert if a condition holds. If it doesn't throw a parse exception. */ + def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + if (!f) { + throw new ParseException(message, ctx) + } + } + + /** + * Register the origin of the context. Any TreeNode created in the closure will be assigned the + * registered origin. This method restores the previously set origin after completion of the + * closure. + */ + def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } + } + + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ + implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { + /** + * Create a plan using the block of code when the given context exists. Otherwise return the + * original plan. + */ + def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f + } else { + plan + } + } + + /** + * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the + * passed function. The original plan is returned when the context does not exist. 
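+ * For example, `plan(ctx.relation).optionalMap(ctx.sample)(withSample)` only applies the
+ * sampling transformation when a TABLESAMPLE clause is present.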
+ */ + def optionalMap[C <: ParserRuleContext]( + ctx: C)( + f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f(ctx, plan) + } else { + plan + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala index c068e895b6..223485e292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala @@ -21,15 +21,20 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { val parser = new CatalystQl() + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + val star = UnresolvedAlias(UnresolvedStar(None)) test("test case insensitive") { - val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation) + val result = OneRowRelation.select(1) assert(result === parser.parsePlan("seLect 1")) assert(result === parser.parsePlan("select 1")) assert(result === parser.parsePlan("SELECT 1")) @@ -37,52 +42,31 @@ class CatalystQlSuite extends PlanTest { test("test NOT operator with comparison operations") { val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE") - val expected = Project( - UnresolvedAlias( - Not( - GreaterThan(Literal(true), Literal(true))) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Not(GreaterThan(true, true))) comparePlans(parsed, expected) } test("test Union Distinct operator") { - val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1") - val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") - val expected = - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - SubqueryAlias("u_1", - Distinct( - Union( - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t0"), None)), - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t1"), None)))))) + val parsed1 = parser.parsePlan( + "SELECT * FROM t0 UNION SELECT * FROM t1") + val parsed2 = parser.parsePlan( + "SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") + val expected = Distinct(Union(table("t0").select(star), table("t1").select(star))) + .as("u_1").select(star) comparePlans(parsed1, expected) comparePlans(parsed2, expected) } test("test Union All operator") { val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1") - val expected = - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - SubqueryAlias("u_1", - Union( - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t0"), None)), - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t1"), None))))) + val expected = Union(table("t0").select(star), table("t1").select(star)).as("u_1").select(star) comparePlans(parsed, expected) } test("support hive interval literal") { def checkInterval(sql: String, result: CalendarInterval): Unit = { val parsed = parser.parsePlan(sql) - val 
expected = Project( - UnresolvedAlias( - Literal(result) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Literal(result)) comparePlans(parsed, expected) } @@ -129,11 +113,7 @@ class CatalystQlSuite extends PlanTest { test("support scientific notation") { def assertRight(input: String, output: Double): Unit = { val parsed = parser.parsePlan("SELECT " + input) - val expected = Project( - UnresolvedAlias( - Literal(output) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Literal(output)) comparePlans(parsed, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 7d3608033b..d9bd33c50a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -18,19 +18,24 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException} import org.apache.spark.sql.types._ -class DataTypeParserSuite extends SparkFunSuite { +abstract class AbstractDataTypeParserSuite extends SparkFunSuite { + + def parse(sql: String): DataType def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { - assert(DataTypeParser.parse(dataTypeString) === expectedDataType) + assert(parse(dataTypeString) === expectedDataType) } } + def intercept(sql: String) + def unsupported(dataTypeString: String): Unit = { test(s"$dataTypeString is not supported") { - intercept[DataTypeException](DataTypeParser.parse(dataTypeString)) + intercept(dataTypeString) } } @@ -97,13 +102,6 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("arrAy", ArrayType(DoubleType, true), true) :: StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) - // A column name can be a reserved word in our DDL parser and SqlParser. - checkDataType( - "Struct", - StructType( - StructField("TABLE", StringType, true) :: - StructField("CASE", BooleanType, true) :: Nil) - ) // Use backticks to quote column names having special characters. checkDataType( "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>", @@ -118,6 +116,43 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("it is not a data type") unsupported("struct") unsupported("struct", + StructType( + StructField("TABLE", StringType, true) :: + StructField("CASE", BooleanType, true) :: Nil) + ) + unsupported("struct") + unsupported("struct<`x``y` int>") } + +class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite { + override def intercept(sql: String): Unit = + intercept[ParseException](CatalystSqlParser.parseDataType(sql)) + + override def parse(sql: String): DataType = + CatalystSqlParser.parseDataType(sql) + + // A column name can be a reserved word in our DDL parser and SqlParser. 
+ unsupported("Struct") + + checkDataType( + "struct", + (new StructType).add("x", IntegerType).add("y", StringType)) + + checkDataType( + "struct<`x``y` int>", + (new StructType).add("x`y", IntegerType)) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala new file mode 100644 index 0000000000..1963fc368f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.SparkFunSuite + +/** + * Test various parser errors. + */ +class ErrorParserSuite extends SparkFunSuite { + def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { + val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) + + // Check position. + assert(e.line.isDefined) + assert(e.line.get === line) + assert(e.startPosition.isDefined) + assert(e.startPosition.get === startPosition) + + // Check messages. 
+ val error = e.getMessage + messages.foreach { message => + assert(error.contains(message)) + } + } + + test("no viable input") { + intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") + intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") + intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") + } + + test("extraneous input") { + intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") + intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") + } + + test("mismatched input") { + intercept("select * from r order by q from t", 1, 27, + "mismatched input", + "---------------------------^^^") + intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") + } + + test("semantic errors") { + intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", + "^^^") + intercept("select * from r where a in (select * from t)", 1, 24, + "IN with a Sub-query is currently not supported", + "------------------------^^^") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala new file mode 100644 index 0000000000..32311a5a66 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Test basic expression parsing. If a type of expression is supported it should be tested here. + * + * Please note that some of the expressions test don't have to be sound expressions, only their + * structure needs to be valid. Unsound expressions should be caught by the Analyzer or + * CheckAnalysis classes. 
+ */ +class ExpressionParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, e: Expression): Unit = { + compareExpressions(parseExpression(sqlCommand), e) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parseExpression(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("star expressions") { + // Global Star + assertEqual("*", UnresolvedStar(None)) + + // Targeted Star + assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) + } + + // NamedExpression (Alias/Multialias) + test("named expressions") { + // No Alias + val r0 = 'a + assertEqual("a", r0) + + // Single Alias. + val r1 = 'a as "b" + assertEqual("a as b", r1) + assertEqual("a b", r1) + + // Multi-Alias + assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + + // Numeric literals without a space between the literal qualifier and the alias, should not be + // interpreted as such. An unresolved reference should be returned instead. + // TODO add the JIRA-ticket number. + assertEqual("1SL", Symbol("1SL")) + + // Aliased star is allowed. + assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + } + + test("binary logical expressions") { + // And + assertEqual("a and b", 'a && 'b) + + // Or + assertEqual("a or b", 'a || 'b) + + // Combination And/Or check precedence + assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) + assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + + // Multiple AND/OR get converted into a balanced tree + assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) + assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + } + + test("long binary logical expressions") { + def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { + val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) + val e = parseExpression(sql) + assert(e.collect { case _: EqualTo => true }.size === 1000) + assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) + } + testVeryBinaryExpression(" AND ", classOf[And]) + testVeryBinaryExpression(" OR ", classOf[Or]) + } + + test("not expressions") { + assertEqual("not a", !'a) + assertEqual("!a", !'a) + assertEqual("not true > true", Not(GreaterThan(true, true))) + } + + test("exists expression") { + intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + } + + test("comparison expressions") { + assertEqual("a = b", 'a === 'b) + assertEqual("a == b", 'a === 'b) + assertEqual("a <=> b", 'a <=> 'b) + assertEqual("a <> b", 'a =!= 'b) + assertEqual("a != b", 'a =!= 'b) + assertEqual("a < b", 'a < 'b) + assertEqual("a <= b", 'a <= 'b) + assertEqual("a > b", 'a > 'b) + assertEqual("a >= b", 'a >= 'b) + } + + test("between expressions") { + assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) + assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + } + + test("in expressions") { + assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) + assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + } + + test("in sub-query") { + intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + } + + test("like expressions") { + assertEqual("a like 'pattern%'", 'a like "pattern%") + assertEqual("a not 
like 'pattern%'", !('a like "pattern%")) + assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) + assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + } + + test("is null expressions") { + assertEqual("a is null", 'a.isNull) + assertEqual("a is not null", 'a.isNotNull) + assertEqual("a = b is null", ('a === 'b).isNull) + assertEqual("a = b is not null", ('a === 'b).isNotNull) + } + + test("binary arithmetic expressions") { + // Simple operations + assertEqual("a * b", 'a * 'b) + assertEqual("a / b", 'a / 'b) + assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a % b", 'a % 'b) + assertEqual("a + b", 'a + 'b) + assertEqual("a - b", 'a - 'b) + assertEqual("a & b", 'a & 'b) + assertEqual("a ^ b", 'a ^ 'b) + assertEqual("a | b", 'a | 'b) + + // Check precedences + assertEqual( + "a * t | b ^ c & d - e + f % g DIV h / i * k", + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + } + + test("unary arithmetic expressions") { + assertEqual("+a", 'a) + assertEqual("-a", -'a) + assertEqual("~a", ~'a) + assertEqual("-+~~a", -(~(~'a))) + } + + test("cast expressions") { + // Note that DataType parsing is tested elsewhere. + assertEqual("cast(a as int)", 'a.cast(IntegerType)) + assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) + assertEqual("cast(a as array)", 'a.cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + } + + test("function expressions") { + assertEqual("foo()", 'foo.function()) + assertEqual("foo.bar()", Symbol("foo.bar").function()) + assertEqual("foo(*)", 'foo.function(star())) + assertEqual("count(*)", 'count.function(1)) + assertEqual("foo(a, b)", 'foo.function('a, 'b)) + assertEqual("foo(all a, b)", 'foo.function('a, 'b)) + assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) + assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) + assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + } + + test("window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } + + // Basic window testing. + assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) + assertEqual("foo(*) over ()", windowed()) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + + // Test use of expressions in window functions. 
+ assertEqual( + "sum(product + 1) over (partition by ((product) + (1)) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + assertEqual( + "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + + // Range/Row + val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) + val boundaries = Seq( + ("10 preceding", ValuePreceding(10), CurrentRow), + ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing), + ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), + ("between current row and 5 following", CurrentRow, ValueFollowing(5)), + ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ) + frameTypes.foreach { + case (frameTypeSql, frameType) => + boundaries.foreach { + case (boundarySql, begin, end) => + val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" + val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + assertEqual(query, expr) + } + } + + // We cannot use non integer constants. + intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", + "Frame bound value must be a constant integer.") + + // We cannot use an arbitrary expression. + intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", + "Frame bound value must be a constant integer.") + } + + test("row constructor") { + // Note that '(a)' will be interpreted as a nested expression. + assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) + assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + } + + test("scalar sub-query") { + assertEqual( + "(select max(val) from tbl) > current", + ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + assertEqual( + "a = (select b from s)", + 'a === ScalarSubquery(table("s").select('b))) + } + + test("case when") { + assertEqual("case a when 1 then b when 2 then c else d end", + CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case when a = 1 then b when a = 2 then c else d end", + CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + } + + test("dereference") { + assertEqual("a.b", UnresolvedAttribute("a.b")) + assertEqual("`select`.b", UnresolvedAttribute("select.b")) + assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + } + + test("reference") { + // Regular + assertEqual("a", 'a) + + // Starting with a digit. + assertEqual("1a", Symbol("1a")) + + // Quoted using a keyword. + assertEqual("`select`", 'select) + + // Unquoted using an unreserved keyword. + assertEqual("columns", 'columns) + } + + test("subscript") { + assertEqual("a[b]", 'a.getItem('b)) + assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + } + + test("parenthesis") { + assertEqual("(a)", 'a) + assertEqual("r * (a + b)", 'r * ('a + 'b)) + } + + test("type constructors") { + // Dates. 
+ assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) + intercept[IllegalArgumentException] { + parseExpression("DAtE 'mar 11 2016'") + } + + // Timestamps. + assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", + Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) + intercept[IllegalArgumentException] { + parseExpression("timestamP '2016-33-11 20:54:00.000'") + } + + // Unsupported datatype. + intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") + } + + test("literals") { + // NULL + assertEqual("null", Literal(null)) + + // Boolean + assertEqual("trUe", Literal(true)) + assertEqual("False", Literal(false)) + + // Integral should have the narrowest possible type + assertEqual("787324", Literal(787324)) + assertEqual("7873247234798249234", Literal(7873247234798249234L)) + assertEqual("78732472347982492793712334", + Literal(BigDecimal("78732472347982492793712334").underlying())) + + // Decimal + assertEqual("7873247234798249279371.2334", + Literal(BigDecimal("7873247234798249279371.2334").underlying())) + + // Scientific Decimal + assertEqual("9.0e1", 90d) + assertEqual(".9e+2", 90d) + assertEqual("0.9e+2", 90d) + assertEqual("900e-1", 90d) + assertEqual("900.0E-1", 90d) + assertEqual("9.e+1", 90d) + intercept(".e3") + + // Tiny Int Literal + assertEqual("10Y", Literal(10.toByte)) + intercept("-1000Y") + + // Small Int Literal + assertEqual("10S", Literal(10.toShort)) + intercept("40000S") + + // Long Int Literal + assertEqual("10L", Literal(10L)) + intercept("78732472347982492793712334L") + + // Double Literal + assertEqual("10.0D", Literal(10.0D)) + // TODO we need to figure out if we should throw an exception here! + assertEqual("1E309", Literal(Double.PositiveInfinity)) + } + + test("strings") { + // Single Strings. + assertEqual("\"hello\"", "hello") + assertEqual("'hello'", "hello") + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld") + assertEqual("'hello' \" \" 'world'", "hello world") + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%") + assertEqual("'no-pattern\\%'", "no-pattern\\%") + assertEqual("'pattern\\\\%'", "pattern\\%") + assertEqual("'pattern\\\\\\%'", "pattern\\\\%") + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') + assertEqual("'\\''", "\'") // Single quote + assertEqual("'\\\"'", "\"") // Double quote + assertEqual("'\\b'", "\b") // Backspace + assertEqual("'\\n'", "\n") // Newline + assertEqual("'\\r'", "\r") // Carriage return + assertEqual("'\\t'", "\t") // Tab character + assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") + + // Unicode + assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + } + + test("intervals") { + def intervalLiteral(u: String, s: String): Literal = { + Literal(CalendarInterval.fromSingleUnitString(u, s)) + } + + // Empty interval statement + intercept("interval", "at least one time unit should be given for interval literal") + + // Single Intervals. 
+ val units = Seq( + "year", + "month", + "week", + "day", + "hour", + "minute", + "second", + "millisecond", + "microsecond") + val forms = Seq("", "s") + val values = Seq("0", "10", "-7", "21") + units.foreach { unit => + forms.foreach { form => + values.foreach { value => + val expected = intervalLiteral(unit, value) + assertEqual(s"interval $value $unit$form", expected) + assertEqual(s"interval '$value' $unit$form", expected) + } + } + } + + // Hive nanosecond notation. + assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) + assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) + + // Non Existing unit + intercept("interval 10 nanoseconds", "No interval can be constructed") + + // Year-Month intervals. + val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") + yearMonthValues.foreach { value => + val result = Literal(CalendarInterval.fromYearMonthString(value)) + assertEqual(s"interval '$value' year to month", result) + } + + // Day-Time intervals. + val datTimeValues = Seq( + "99 11:22:33.123456789", + "-99 11:22:33.123456789", + "10 9:8:7.123456789", + "1 0:0:0", + "-1 0:0:0", + "1 0:0:1") + datTimeValues.foreach { value => + val result = Literal(CalendarInterval.fromDayTimeString(value)) + assertEqual(s"interval '$value' day to second", result) + } + + // Unknown FROM TO intervals + intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") + + // Composed intervals. + assertEqual( + "interval 3 months 22 seconds 1 millisecond", + Literal(new CalendarInterval(3, 22001000L))) + assertEqual( + "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", + Literal(new CalendarInterval(14, + 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) + } + + test("composed expressions") { + assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) + assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala new file mode 100644 index 0000000000..4206d22ca7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PlanParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parsePlan(sqlCommand), plan) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("case insensitive") { + val plan = table("a").select(star()) + assertEqual("sELEct * FroM a", plan) + assertEqual("select * fRoM a", plan) + assertEqual("SELECT * FROM a", plan) + } + + test("show functions") { + assertEqual("show functions", ShowFunctions(None, None)) + assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) + assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) + assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) + intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") + } + + test("describe function") { + assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) + assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) + assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) + assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + } + + test("set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + + assertEqual("select * from a union select * from b", Distinct(a.union(b))) + assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) + assertEqual("select * from a union all select * from b", a.union(b)) + assertEqual("select * from a except select * from b", a.except(b)) + intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") + assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a intersect select * from b", a.intersect(b)) + intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") + assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + } + + test("common table expressions") { + def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + val ctes = namedPlans.map { + case (name, cte) => + name -> SubqueryAlias(name, cte) + }.toMap + With(plan, ctes) + } + assertEqual( + "with cte1 as (select * from a) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + assertEqual( + "with cte1 (select 1) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) + assertEqual( + "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", + cte(table("cte2").select(star()), + "cte1" -> OneRowRelation.select(1), + "cte2" -> table("cte1").select(star()))) + intercept( + "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", + "Name 'cte1' is used for multiple common table expressions") + } + + test("simple select query") { + 
assertEqual("select 1", OneRowRelation.select(1)) + assertEqual("select a, b", OneRowRelation.select('a, 'b)) + assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual( + "select a, b from db.c having x < 1", + table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) + assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + } + + test("reverse select query") { + assertEqual("from a", table("a")) + assertEqual("from a select b, c", table("a").select('b, 'c)) + assertEqual( + "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) + assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual( + "from (from a union all from b) c select *", + table("a").union(table("b")).as("c").select(star())) + } + + test("transform query spec") { + val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) + assertEqual("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + assertEqual("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("multi select query") { + assertEqual( + "from a select * select * where s < 10", + table("a").select(star()).union(table("a").where('s < 10).select(star()))) + intercept( + "from a select * select * from x where a.s < 10", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + assertEqual( + "from a insert into tbl1 select * insert into tbl2 select * where s < 10", + table("a").select(star()).insertInto("tbl1").union( + table("a").where('s < 10).select(star()).insertInto("tbl2"))) + } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = table("t").select(star()) + + val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) + val limitWindowClauses = Seq( + ("", (p: LogicalPlan) => p), + (" limit 10", (p: LogicalPlan) => p.limit(10)), + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), + (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + ) + + val orderSortDistrClusterClauses = Seq( + ("", basePlan), + (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), + (" distribute by a, b", basePlan.distribute('a, 'b)), + (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), + (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) + ) + + orderSortDistrClusterClauses.foreach { + case (s1, p1) => + limitWindowClauses.foreach { + case (s2, pf2) => + assertEqual(baseSql + s1 + s2, pf2(p1)) + } + } + + val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" + intercept(s"$baseSql order by a sort by a", msg) + intercept(s"$baseSql cluster by a distribute by a", msg) + intercept(s"$baseSql order by a cluster by a", msg) + intercept(s"$baseSql order by a distribute by a", msg) + } + + test("insert into") { + val sql = "select * from t" 
+ val plan = table("t").select(star()) + def insert( + partition: Map[String, Option[String]], + overwrite: Boolean = false, + ifNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + + // Single inserts + assertEqual(s"insert overwrite table s $sql", + insert(Map.empty, overwrite = true)) + assertEqual(s"insert overwrite table s if not exists $sql", + insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert into s $sql", + insert(Map.empty)) + assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", + insert(Map("c" -> Option("d"), "e" -> Option("1")))) + assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", + insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) + + // Multi insert + val plan2 = table("t").where('x > 5).select(star()) + assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", + InsertIntoTable( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + InsertIntoTable( + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + } + + test("aggregation") { + val sql = "select a, b, sum(c) as c from d group by a, b" + + // Normal + assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + + // Cube + assertEqual(s"$sql with cube", + table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Rollup + assertEqual(s"$sql with rollup", + table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Grouping Sets + assertEqual(s"$sql grouping sets((a, b), (a), ())", + GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + intercept(s"$sql grouping sets((a, b), (c), ())", + "c doesn't show up in the GROUP BY list") + } + + test("limit") { + val sql = "select * from t" + val plan = table("t").select(star()) + assertEqual(s"$sql limit 10", plan.limit(10)) + assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) + } + + test("window spec") { + // Note that WindowSpecs are testing in the ExpressionParserSuite + val sql = "select * from t" + val plan = table("t").select(star()) + val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + + // Test window resolution. + val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) + assertEqual( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w1""".stripMargin, + WithWindowDefinition(ws1, plan)) + + // Fail with no reference. + intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") + + // Fail when resolved reference is not a window spec. 
+ intercept( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w2""".stripMargin, + "Window reference 'w2' is not a window specification" + ) + } + + test("lateral view") { + // Single lateral view + assertEqual( + "select * from t lateral view explode(x) expl as x", + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .select(star())) + + // Multiple lateral views + assertEqual( + """select * + |from t + |lateral view explode(x) expl + |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .select(star())) + + // Multi-Insert lateral views. + val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + assertEqual( + """from t1 + |lateral view explode(x) expl as x + |insert into t2 + |select * + |lateral view json_tuple(x, y) jtup q, z + |insert into t3 + |select * + |where s < 10 + """.stripMargin, + Union(from + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .select(star()) + .insertInto("t2"), + from.where('s < 10).select(star()).insertInto("t3"))) + + // Unsupported generator. + intercept( + "select * from t lateral view posexplode(x) posexpl as x, y", + "Generator function 'posexplode' is not supported") + } + + test("joins") { + // Test single joins. + val testUnconditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t as tt $sql u", + table("t").as("tt").join(table("u"), jt, None).select(star())) + } + val testConditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u as uu on a = b", + table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + } + val testNaturalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t tt natural $sql u as uu", + table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) + } + val testUsingJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u using(a, b)", + table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) + } + val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) + + def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { + tests.foreach(_(sql, jt)) + } + test("cross join", Inner, Seq(testUnconditionalJoin)) + test(",", Inner, Seq(testUnconditionalJoin)) + test("join", Inner, testAll) + test("inner join", Inner, testAll) + test("left join", LeftOuter, testAll) + test("left outer join", LeftOuter, testAll) + test("right join", RightOuter, testAll) + test("right outer join", RightOuter, testAll) + test("full join", FullOuter, testAll) + test("full outer join", FullOuter, testAll) + + // Test multiple consecutive joins + assertEqual( + "select * from a join b join c right join d", + table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + } + + test("sampled relations") { + val sql = "select * from t" + assertEqual(s"$sql tablesample(100 rows)", + table("t").limit(100).select(star())) + assertEqual(s"$sql tablesample(43 percent) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) 
as x", + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", + "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + intercept(s"$sql tablesample(bucket 11 out of 10) as x", + s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + } + + test("sub-query") { + val plan = table("t0").select('id) + assertEqual("select id from (t0)", plan) + assertEqual("select id from ((((((t0))))))", plan) + assertEqual( + "(select * from t1) union distinct (select * from t2)", + Distinct(table("t1").select(star()).union(table("t2").select(star())))) + assertEqual( + "select * from ((select * from t1) union (select * from t2)) t", + Distinct( + table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) + assertEqual( + """select id + |from (((select id from t0) + | union all + | (select id from t0)) + | union all + | (select id from t0)) as u_1 + """.stripMargin, + plan.union(plan).union(plan).as("u_1").select('id)) + } + + test("scalar sub-query") { + assertEqual( + "select (select max(b) from s) ss from t", + table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + assertEqual( + "select * from t where a = (select b from s)", + table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + assertEqual( + "select g from t group by g having a > (select b from s)", + table("t") + .groupBy('g)('g) + .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + } + + test("table reference") { + assertEqual("table t", table("t")) + assertEqual("table d.t", table("d", "t")) + } + + test("inline table") { + assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( + Seq('col1.int), + Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual( + "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", + LocalRelation.fromExternalRows( + Seq('a.int, 'b.string), + Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) + intercept("values (a, 'a'), (b, 'b')", + "All expressions in an inline table must be constants.") + intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", + "Number of aliases must match the number of fields in an inline table.") + intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala new file mode 100644 index 0000000000..0874322187 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier + +class TableIdentifierParserSuite extends SparkFunSuite { + import CatalystSqlParser._ + + test("table identifier") { + // Regular names. + assert(TableIdentifier("q") === parseTableIdentifier("q")) + assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) + + // Illegal names. + intercept[ParseException](parseTableIdentifier("")) + intercept[ParseException](parseTableIdentifier("d.q.g")) + + // SQL Keywords. + val keywords = Seq("select", "from", "where", "left", "right") + keywords.foreach { keyword => + intercept[ParseException](parseTableIdentifier(keyword)) + assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) + assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 0541844e0b..aa5d4330d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ /** @@ -32,6 +32,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { */ protected def normalizeExprIds(plan: LogicalPlan) = { plan transformAllExpressions { + case s: ScalarSubquery => + ScalarSubquery(s.query, ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => @@ -40,21 +42,25 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { } /** - * Normalizes the filter conditions that appear in the plan. For instance, - * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) - * etc., will all now be equivalent. + * Normalizes plans: + * - Filter the filter conditions that appear in a plan. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) + * etc., will all now be equivalent. + * - Sample the seed will replaced by 0L. 
*/ - private def normalizeFilters(plan: LogicalPlan) = { + private def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child) + case sample: Sample => + sample.copy(seed = 0L)(true) } } /** Fails the test if the two plans do not match */ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeFilters(normalizeExprIds(plan1)) - val normalized2 = normalizeFilters(normalizeExprIds(plan2)) + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) if (normalized1 != normalized2) { fail( s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala new file mode 100644 index 0000000000..c098fa99c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder} +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources._ + +/** + * Concrete parser for Spark SQL statements. + */ +object SparkSqlParser extends AbstractSqlParser{ + val astBuilder = new SparkSqlAstBuilder +} + +/** + * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. + */ +class SparkSqlAstBuilder extends AstBuilder { + import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._ + + /** + * Create a [[SetCommand]] logical plan. + * + * Note that we assume that everything after the SET keyword is assumed to be a part of the + * key-value pair. The split between key and value is made by searching for the first `=` + * character in the raw string. + */ + override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) { + // Construct the command. 
+ val raw = remainder(ctx.SET.getSymbol) + val keyValueSeparatorIndex = raw.indexOf('=') + if (keyValueSeparatorIndex >= 0) { + val key = raw.substring(0, keyValueSeparatorIndex).trim + val value = raw.substring(keyValueSeparatorIndex + 1).trim + SetCommand(Some(key -> Option(value))) + } else if (raw.nonEmpty) { + SetCommand(Some(raw.trim -> None)) + } else { + SetCommand(None) + } + } + + /** + * Create a [[SetDatabaseCommand]] logical plan. + */ + override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) { + SetDatabaseCommand(ctx.db.getText) + } + + /** + * Create a [[ShowTablesCommand]] logical plan. + */ + override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { + if (ctx.LIKE != null) { + logWarning("SHOW TABLES LIKE option is ignored.") + } + ShowTablesCommand(Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[RefreshTable]] logical plan. + */ + override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { + RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) + } + + /** + * Create a [[CacheTableCommand]] logical plan. + */ + override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { + val query = Option(ctx.query).map(plan) + CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null) + } + + /** + * Create an [[UncacheTableCommand]] logical plan. + */ + override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { + UncacheTableCommand(ctx.identifier.getText) + } + + /** + * Create a [[ClearCacheCommand]] logical plan. + */ + override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { + ClearCacheCommand + } + + /** + * Create an [[ExplainCommand]] logical plan. + */ + override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { + val options = ctx.explainOption.asScala + if (options.exists(_.FORMATTED != null)) { + logWarning("EXPLAIN FORMATTED option is ignored.") + } + if (options.exists(_.LOGICAL != null)) { + logWarning("EXPLAIN LOGICAL option is ignored.") + } + + // Create the explain comment. + val statement = plan(ctx.statement) + if (isExplainableStatement(statement)) { + ExplainCommand(statement, extended = options.exists(_.EXTENDED != null)) + } else { + ExplainCommand(OneRowRelation) + } + } + + /** + * Determine if a plan should be explained at all. + */ + protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match { + case _: datasources.DescribeCommand => false + case _ => true + } + + /** + * Create a [[DescribeCommand]] logical plan. + */ + override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { + // FORMATTED and columns are not supported. Return null and let the parser decide what to do + // with this (create an exception or pass it on to a different system). + if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) { + null + } else { + datasources.DescribeCommand( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXTENDED != null) + } + } + + /** Type to keep track of a table header. */ + type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. 
+ */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + assert(!temporary || !ifNotExists, + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.", + ctx) + (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan. + * + * TODO add bucketing and partitioning. + */ + override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + if (external) { + logWarning("EXTERNAL option is not supported.") + } + val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val provider = ctx.tableProvider.qualifiedName.getText + + if (ctx.query != null) { + // Get the backing query. + val query = plan(ctx.query) + + // Determine the storage mode. + val mode = if (ifNotExists) { + SaveMode.Ignore + } else if (temp) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query) + } else { + val struct = Option(ctx.colTypeList).map(createStructType) + CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false) + } + } + + /** + * Convert a table property list into a key-value map. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + ctx.tableProperty.asScala.map { property => + // A key can either be a String or a collection of dot separated elements. We need to treat + // these differently. 
+ val key = if (property.key.STRING != null) { + string(property.key.STRING) + } else { + property.key.getText + } + val value = Option(property.value).map(string).orNull + key -> value + }.toMap + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8abb9d7e4a..7ce15e3f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.parser.CatalystQl import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1172,8 +1172,7 @@ object functions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl()) - Column(parser.parseExpression(expr)) + Column(SparkSqlParser.parseExpression(expr)) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index e5f02caabc..9bc640763f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -81,7 +81,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. */ - lazy val sqlParser: ParserInterface = new SparkQl(conf) + lazy val sqlParser: ParserInterface = SparkSqlParser /** * Planner that converts optimized logical plans to physical plans. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5af1a4fcd7..a5a4ff13de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -329,8 +329,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") + upperCaseData.where('N <= 4).registerTempTable("`left`") + upperCaseData.where('N >= 3).registerTempTable("`right`") val left = UnresolvedRelation(TableIdentifier("left"), None) val right = UnresolvedRelation(TableIdentifier("right"), None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c958eac266..b727e88668 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1656,7 +1656,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e2 = intercept[AnalysisException] { sql("select interval 23 nanosecond") } - assert(e2.message.contains("cannot recognize input near")) + assert(e2.message.contains("No interval can be constructed")) } test("SPARK-8945: add and subtract expressions for interval type") { -- cgit v1.2.3 From d7b58f1461f71ee3c028360eef0ffedd17d6a076 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 28 Mar 2016 13:07:32 -0700 Subject: [SPARK-14052] [SQL] build a BytesToBytesMap directly in HashedRelation ## What changes were proposed in this pull request? Currently, for the key that can not fit within a long, we build a hash map for UnsafeHashedRelation, it's converted to BytesToBytesMap after serialization and deserialization. We should build a BytesToBytesMap directly to have better memory efficiency. In order to do that, BytesToBytesMap should support multiple (K,V) pair with the same K, Location.putNewKey() is renamed to Location.append(), which could append multiple values for the same key (same Location). `Location.newValue()` is added to find the next value for the same key. ## How was this patch tested? Existing tests. Added benchmark for broadcast hash join with duplicated keys. Author: Davies Liu Closes #11870 from davies/map2. 
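
To make the new API concrete, here is a minimal illustrative Scala sketch (not part of the patch) of appending several values under one key and walking them back with the `nextValue()` method added below; it assumes a `TaskMemoryManager` named `taskMemoryManager` and a 4 MB page size, mirroring the setup used in the existing map test suite:

```scala
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap

// `taskMemoryManager` is assumed to be in scope (built the way the map test suite does it).
val map = new BytesToBytesMap(taskMemoryManager, 64, 4 * 1024 * 1024)
try {
  val key = Array(42L)
  // Append two values under the same key; the old putNewKey() allowed only one.
  for (value <- Seq(Array(1L), Array(2L))) {
    val loc = map.lookup(key, Platform.LONG_ARRAY_OFFSET, 8)
    loc.append(key, Platform.LONG_ARRAY_OFFSET, 8, value, Platform.LONG_ARRAY_OFFSET, 8)
  }
  assert(map.numKeys() == 1 && map.numValues() == 2)

  // Walk the chain of values stored under the key.
  val loc = map.lookup(key, Platform.LONG_ARRAY_OFFSET, 8)
  var hasValue = loc.isDefined
  while (hasValue) {
    println(Platform.getLong(loc.getValueBase, loc.getValueOffset))
    hasValue = loc.nextValue()
  }
} finally {
  map.free()
}
```

Duplicate values live in the map's data pages and are chained through an extra 8-byte pointer appended to each record, which is what allows `UnsafeHashedRelation` below to be backed by the map directly.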
--- .../apache/spark/unsafe/map/BytesToBytesMap.java | 113 ++++++--- .../unsafe/map/AbstractBytesToBytesMapSuite.java | 64 ++++- .../execution/UnsafeFixedWidthAggregationMap.java | 2 +- .../spark/sql/execution/joins/HashedRelation.scala | 281 +++++++++------------ .../sql/execution/joins/ShuffledHashJoin.scala | 18 +- .../sql/execution/BenchmarkWholeStageCodegen.scala | 26 +- .../sql/execution/joins/HashedRelationSuite.scala | 15 +- 7 files changed, 299 insertions(+), 220 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 9aacb084f6..32958be7a7 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -56,9 +56,10 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillWriter; * Bytes 4 to 8: len(k) * Bytes 8 to 8 + len(k): key data * Bytes 8 + len(k) to 8 + len(k) + len(v): value data + * Bytes 8 + len(k) + len(v) to 8 + len(k) + len(v) + 8: pointer to next pair * * This means that the first four bytes store the entire record (key + value) length. This format - * is consistent with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, + * is compatible with {@link org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter}, * so we can pass records from this map directly into the sorter to sort records in place. */ public final class BytesToBytesMap extends MemoryConsumer { @@ -132,7 +133,12 @@ public final class BytesToBytesMap extends MemoryConsumer { /** * Number of keys defined in the map. */ - private int numElements; + private int numKeys; + + /** + * Number of values defined in the map. A key could have multiple values. + */ + private int numValues; /** * The map will be expanded once the number of keys exceeds this threshold. @@ -223,7 +229,12 @@ public final class BytesToBytesMap extends MemoryConsumer { /** * Returns the number of keys defined in the map. */ - public int numElements() { return numElements; } + public int numKeys() { return numKeys; } + + /** + * Returns the number of values defined in the map. A key could have multiple values. + */ + public int numValues() { return numValues; } public final class MapIterator implements Iterator { @@ -311,7 +322,8 @@ public final class BytesToBytesMap extends MemoryConsumer { if (currentPage != null) { int totalLength = Platform.getInt(pageBaseObject, offsetInPage); loc.with(currentPage, offsetInPage); - offsetInPage += 4 + totalLength; + // [total size] [key size] [key] [value] [pointer to next] + offsetInPage += 4 + totalLength + 8; recordsInPage --; return loc; } else { @@ -361,7 +373,7 @@ public final class BytesToBytesMap extends MemoryConsumer { while (numRecords > 0) { int length = Platform.getInt(base, offset); writer.write(base, offset + 4, length, 0); - offset += 4 + length; + offset += 4 + length + 8; numRecords--; } writer.close(); @@ -395,7 +407,7 @@ public final class BytesToBytesMap extends MemoryConsumer { * `lookup()`, the behavior of the returned iterator is undefined. */ public MapIterator iterator() { - return new MapIterator(numElements, loc, false); + return new MapIterator(numValues, loc, false); } /** @@ -409,7 +421,7 @@ public final class BytesToBytesMap extends MemoryConsumer { * `lookup()`, the behavior of the returned iterator is undefined. 
*/ public MapIterator destructiveIterator() { - return new MapIterator(numElements, loc, true); + return new MapIterator(numValues, loc, true); } /** @@ -559,6 +571,20 @@ public final class BytesToBytesMap extends MemoryConsumer { return this; } + /** + * Find the next pair that has the same key as current one. + */ + public boolean nextValue() { + assert isDefined; + long nextAddr = Platform.getLong(baseObject, valueOffset + valueLength); + if (nextAddr == 0) { + return false; + } else { + updateAddressesAndSizes(nextAddr); + return true; + } + } + /** * Returns the memory page that contains the current record. * This is only valid if this is returned by {@link BytesToBytesMap#iterator()}. @@ -625,10 +651,9 @@ public final class BytesToBytesMap extends MemoryConsumer { } /** - * Store a new key and value. This method may only be called once for a given key; if you want - * to update the value associated with a key, then you can directly manipulate the bytes stored - * at the value address. The return value indicates whether the put succeeded or whether it - * failed because additional memory could not be acquired. + * Append a new value for the key. This method could be called multiple times for a given key. + * The return value indicates whether the put succeeded or whether it failed because additional + * memory could not be acquired. *

      * It is only valid to call this method immediately after calling `lookup()` using the same key.
      *
@@ -637,7 +662,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
      *
      *
      * After calling this method, calls to `get[Key|Value]Address()` and `get[Key|Value]Length`
-     * will return information on the data stored by this `putNewKey` call.
+     * will return information on the data stored by this `append` call.
      *
      *
      * As an example usage, here's the proper way to store a new key:
@@ -645,7 +670,7 @@ public final class BytesToBytesMap extends MemoryConsumer {
      *
      *   Location loc = map.lookup(keyBase, keyOffset, keyLength);
      *   if (!loc.isDefined()) {
-     *     if (!loc.putNewKey(keyBase, keyOffset, keyLength, ...)) {
+     *     if (!loc.append(keyBase, keyOffset, keyLength, ...)) {
      *       // handle failure to grow map (by spilling, for example)
      *     }
      *   }
@@ -657,26 +682,23 @@ public final class BytesToBytesMap extends MemoryConsumer {
      * @return true if the put() was successful and false if the put() failed because memory could
      *         not be acquired.
      */
-    public boolean putNewKey(Object keyBase, long keyOffset, int keyLength,
-        Object valueBase, long valueOffset, int valueLength) {
-      assert (!isDefined) : "Can only set value once for a key";
-      assert (keyLength % 8 == 0);
-      assert (valueLength % 8 == 0);
-      assert(longArray != null);
-
+    public boolean append(Object kbase, long koff, int klen, Object vbase, long voff, int vlen) {
+      assert (klen % 8 == 0);
+      assert (vlen % 8 == 0);
+      assert (longArray != null);
 
-      if (numElements == MAX_CAPACITY
+      if (numKeys == MAX_CAPACITY
         // The map could be reused from last spill (because of no enough memory to grow),
         // then we don't try to grow again if hit the `growthThreshold`.
-        || !canGrowArray && numElements > growthThreshold) {
+        || !canGrowArray && numKeys > growthThreshold) {
         return false;
       }
 
       // Here, we'll copy the data into our data pages. Because we only store a relative offset from
       // the key address instead of storing the absolute address of the value, the key and value
       // must be stored in the same memory page.
-      // (8 byte key length) (key) (value)
-      final long recordLength = 8 + keyLength + valueLength;
+      // (8 byte key length) (key) (value) (8 byte pointer to next value)
+      final long recordLength = 8 + klen + vlen + 8;
       if (currentPage == null || currentPage.size() - pageCursor < recordLength) {
         if (!acquireNewPage(recordLength + 4L)) {
           return false;
@@ -687,30 +709,40 @@ public final class BytesToBytesMap extends MemoryConsumer {
       final Object base = currentPage.getBaseObject();
       long offset = currentPage.getBaseOffset() + pageCursor;
       final long recordOffset = offset;
-      Platform.putInt(base, offset, keyLength + valueLength + 4);
-      Platform.putInt(base, offset + 4, keyLength);
+      Platform.putInt(base, offset, klen + vlen + 4);
+      Platform.putInt(base, offset + 4, klen);
       offset += 8;
-      Platform.copyMemory(keyBase, keyOffset, base, offset, keyLength);
-      offset += keyLength;
-      Platform.copyMemory(valueBase, valueOffset, base, offset, valueLength);
+      Platform.copyMemory(kbase, koff, base, offset, klen);
+      offset += klen;
+      Platform.copyMemory(vbase, voff, base, offset, vlen);
+      offset += vlen;
+      Platform.putLong(base, offset, 0);
 
       // --- Update bookkeeping data structures ----------------------------------------------------
       offset = currentPage.getBaseOffset();
       Platform.putInt(base, offset, Platform.getInt(base, offset) + 1);
       pageCursor += recordLength;
-      numElements++;
       final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset(
         currentPage, recordOffset);
-      longArray.set(pos * 2, storedKeyAddress);
-      longArray.set(pos * 2 + 1, keyHashcode);
-      updateAddressesAndSizes(storedKeyAddress);
-      isDefined = true;
+      numValues++;
+      if (isDefined) {
+        // put this pair at the end of the chain
+        while (nextValue()) { /* do nothing */ }
+        Platform.putLong(baseObject, valueOffset + valueLength, storedKeyAddress);
+        nextValue();  // point to the newly added value
+      } else {
+        numKeys++;
+        longArray.set(pos * 2, storedKeyAddress);
+        longArray.set(pos * 2 + 1, keyHashcode);
+        updateAddressesAndSizes(storedKeyAddress);
+        isDefined = true;
 
-      if (numElements > growthThreshold && longArray.size() < MAX_CAPACITY) {
-        try {
-          growAndRehash();
-        } catch (OutOfMemoryError oom) {
-          canGrowArray = false;
+        if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) {
+          try {
+            growAndRehash();
+          } catch (OutOfMemoryError oom) {
+            canGrowArray = false;
+          }
         }
       }
       return true;
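
As an editorial aside (not part of the patch), the sizes in the new record layout written by `append()` above and the stride used by the map iterator agree; a quick Scala check for an 8-byte key and a 16-byte value:

```scala
// Worked example of the "(8 byte key length) (key) (value) (8 byte pointer to next value)" layout.
val klen = 8
val vlen = 16
val totalLength  = klen + vlen + 4      // stored in the first 4 bytes of the record
val recordLength = 8 + klen + vlen + 8  // bytes reserved by Location.append()
// MapIterator advances with `offsetInPage += 4 + totalLength + 8`, i.e. exactly one record.
assert(recordLength == 4 + totalLength + 8)
```
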
@@ -866,7 +898,8 @@ public final class BytesToBytesMap extends MemoryConsumer {
    * Reset this map to initialized state.
    */
   public void reset() {
-    numElements = 0;
+    numKeys = 0;
+    numValues = 0;
     longArray.zeroOut();
 
     while (dataPages.size() > 0) {
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 449fb45c30..84b82f5a47 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -182,7 +182,7 @@ public abstract class AbstractBytesToBytesMapSuite {
   public void emptyMap() {
     BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 64, PAGE_SIZE_BYTES);
     try {
-      Assert.assertEquals(0, map.numElements());
+      Assert.assertEquals(0, map.numKeys());
       final int keyLengthInWords = 10;
       final int keyLengthInBytes = keyLengthInWords * 8;
       final byte[] key = getRandomByteArray(keyLengthInWords);
@@ -204,7 +204,7 @@ public abstract class AbstractBytesToBytesMapSuite {
       final BytesToBytesMap.Location loc =
         map.lookup(keyData, Platform.BYTE_ARRAY_OFFSET, recordLengthBytes);
       Assert.assertFalse(loc.isDefined());
-      Assert.assertTrue(loc.putNewKey(
+      Assert.assertTrue(loc.append(
         keyData,
         Platform.BYTE_ARRAY_OFFSET,
         recordLengthBytes,
@@ -232,7 +232,7 @@ public abstract class AbstractBytesToBytesMapSuite {
         getByteArray(loc.getValueBase(), loc.getValueOffset(), recordLengthBytes));
 
       try {
-        Assert.assertTrue(loc.putNewKey(
+        Assert.assertTrue(loc.append(
           keyData,
           Platform.BYTE_ARRAY_OFFSET,
           recordLengthBytes,
@@ -260,7 +260,7 @@ public abstract class AbstractBytesToBytesMapSuite {
         Assert.assertFalse(loc.isDefined());
         // Ensure that we store some zero-length keys
         if (i % 5 == 0) {
-          Assert.assertTrue(loc.putNewKey(
+          Assert.assertTrue(loc.append(
             null,
             Platform.LONG_ARRAY_OFFSET,
             0,
@@ -269,7 +269,7 @@ public abstract class AbstractBytesToBytesMapSuite {
             8
           ));
         } else {
-          Assert.assertTrue(loc.putNewKey(
+          Assert.assertTrue(loc.append(
             value,
             Platform.LONG_ARRAY_OFFSET,
             8,
@@ -349,7 +349,7 @@ public abstract class AbstractBytesToBytesMapSuite {
           KEY_LENGTH
         );
         Assert.assertFalse(loc.isDefined());
-        Assert.assertTrue(loc.putNewKey(
+        Assert.assertTrue(loc.append(
           key,
           Platform.LONG_ARRAY_OFFSET,
           KEY_LENGTH,
@@ -417,7 +417,7 @@ public abstract class AbstractBytesToBytesMapSuite {
             key.length
           );
           Assert.assertFalse(loc.isDefined());
-          Assert.assertTrue(loc.putNewKey(
+          Assert.assertTrue(loc.append(
             key,
             Platform.BYTE_ARRAY_OFFSET,
             key.length,
@@ -471,7 +471,7 @@ public abstract class AbstractBytesToBytesMapSuite {
             key.length
           );
           Assert.assertFalse(loc.isDefined());
-          Assert.assertTrue(loc.putNewKey(
+          Assert.assertTrue(loc.append(
             key,
             Platform.BYTE_ARRAY_OFFSET,
             key.length,
@@ -514,7 +514,7 @@ public abstract class AbstractBytesToBytesMapSuite {
       final BytesToBytesMap.Location loc =
         map.lookup(emptyArray, Platform.LONG_ARRAY_OFFSET, 0);
       Assert.assertFalse(loc.isDefined());
-      Assert.assertFalse(loc.putNewKey(
+      Assert.assertFalse(loc.append(
         emptyArray, Platform.LONG_ARRAY_OFFSET, 0, emptyArray, Platform.LONG_ARRAY_OFFSET, 0));
     } finally {
       map.free();
@@ -535,7 +535,7 @@ public abstract class AbstractBytesToBytesMapSuite {
         final long[] arr = new long[]{i};
         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
         success =
-          loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+          loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
         if (!success) {
           break;
         }
@@ -556,7 +556,7 @@ public abstract class AbstractBytesToBytesMapSuite {
       for (i = 0; i < 1024; i++) {
         final long[] arr = new long[]{i};
         final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
-        loc.putNewKey(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+        loc.append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
       }
       BytesToBytesMap.MapIterator iter = map.iterator();
       for (i = 0; i < 100; i++) {
@@ -586,6 +586,44 @@ public abstract class AbstractBytesToBytesMapSuite {
     }
   }
 
+  @Test
+  public void multipleValuesForSameKey() {
+    BytesToBytesMap map =
+      new BytesToBytesMap(taskMemoryManager, blockManager, serializerManager, 1, 0.75, 1024, false);
+    try {
+      int i;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8)
+          .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+      }
+      assert map.numKeys() == 1024;
+      assert map.numValues() == 1024;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8)
+          .append(arr, Platform.LONG_ARRAY_OFFSET, 8, arr, Platform.LONG_ARRAY_OFFSET, 8);
+      }
+      assert map.numKeys() == 1024;
+      assert map.numValues() == 2048;
+      for (i = 0; i < 1024; i++) {
+        final long[] arr = new long[]{i};
+        final BytesToBytesMap.Location loc = map.lookup(arr, Platform.LONG_ARRAY_OFFSET, 8);
+        assert loc.isDefined();
+        assert loc.nextValue();
+        assert !loc.nextValue();
+      }
+      BytesToBytesMap.MapIterator iter = map.iterator();
+      for (i = 0; i < 2048; i++) {
+        assert iter.hasNext();
+        final BytesToBytesMap.Location loc = iter.next();
+        assert loc.isDefined();
+      }
+    } finally {
+      map.free();
+    }
+  }
+
   @Test
   public void initialCapacityBoundsChecking() {
     try {
@@ -608,7 +646,7 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void testPeakMemoryUsed() {
-    final long recordLengthBytes = 24;
+    final long recordLengthBytes = 32;
     final long pageSizeBytes = 256 + 8; // 8 bytes for end-of-page marker
     final long numRecordsPerPage = (pageSizeBytes - 8) / recordLengthBytes;
     final BytesToBytesMap map = new BytesToBytesMap(taskMemoryManager, 1024, pageSizeBytes);
@@ -622,7 +660,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     try {
       for (long i = 0; i < numRecordsPerPage * 10; i++) {
         final long[] value = new long[]{i};
-        map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).putNewKey(
+        map.lookup(value, Platform.LONG_ARRAY_OFFSET, 8).append(
           value,
           Platform.LONG_ARRAY_OFFSET,
           8,
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 8882903bbf..1f1b5389aa 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -134,7 +134,7 @@ public final class UnsafeFixedWidthAggregationMap {
     if (!loc.isDefined()) {
       // This is the first time that we've seen this grouping key, so we'll insert a copy of the
       // empty aggregation buffer into the map:
-      boolean putSucceeded = loc.putNewKey(
+      boolean putSucceeded = loc.append(
         key.getBaseObject(),
         key.getBaseOffset(),
         key.getSizeInBytes(),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 8cc3528639..dc4793e85a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -18,18 +18,18 @@
 package org.apache.spark.sql.execution.joins
 
 import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
-import java.nio.ByteOrder
 import java.util.{HashMap => JavaHashMap}
 
-import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
 import org.apache.spark.sql.execution.SparkSqlSerializer
 import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
 import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
+import org.apache.spark.util.{KnownSizeEstimation, Utils}
 import org.apache.spark.util.collection.CompactBuffer
 
 /**
@@ -54,6 +54,11 @@ private[execution] sealed trait HashedRelation {
     */
   def getMemorySize: Long = 1L  // to make the test happy
 
+  /**
+   * Release any used resources.
+   */
+  def close(): Unit = {}
+
   // This is a helper method to implement Externalizable, and is used by
   // GeneralHashedRelation and UniqueKeyHashedRelation
   protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = {
@@ -132,163 +137,83 @@ private[execution] object HashedRelation {
 }
 
 /**
- * A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
- * into a sequence of values.
- *
- * When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
- * BytesToBytesMap for better memory performance (multiple values for the same are stored as a
- * continuous byte array.
+ * A HashedRelation for UnsafeRow, which is backed by a BytesToBytesMap.
  *
  * It's serialized in the following format:
  *  [number of keys]
- *  [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
- *  ...
- *
- * All the values are serialized as following:
- *   [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
- *   ...
+ *  [size of key] [size of value] [key bytes] [bytes for value]
  */
-private[joins] final class UnsafeHashedRelation(
-    private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
-  extends HashedRelation
-  with KnownSizeEstimation
-  with Externalizable {
-
-  private[joins] def this() = this(null)  // Needed for serialization
+private[joins] class UnsafeHashedRelation(
+    private var numFields: Int,
+    private var binaryMap: BytesToBytesMap)
+  extends HashedRelation with KnownSizeEstimation with Externalizable {
 
-  // Use BytesToBytesMap in executor for better performance (it's created when deserialization)
-  // This is used in broadcast joins and distributed mode only
-  @transient private[this] var binaryMap: BytesToBytesMap = _
+  private[joins] def this() = this(0, null)  // Needed for serialization
 
-  /**
-   * Return the size of the unsafe map on the executors.
-   *
-   * For broadcast joins, this hashed relation is bigger on the driver because it is
-   * represented as a Java hash map there. While serializing the map to the executors,
-   * however, we rehash the contents in a binary map to reduce the memory footprint on
-   * the executors.
-   *
-   * For non-broadcast joins or in local mode, return 0.
-   */
   override def getMemorySize: Long = {
-    if (binaryMap != null) {
-      binaryMap.getTotalMemoryConsumption
-    } else {
-      0
-    }
+    binaryMap.getTotalMemoryConsumption
   }
 
   override def estimatedSize: Long = {
-    if (binaryMap != null) {
-      binaryMap.getTotalMemoryConsumption
-    } else {
-      SizeEstimator.estimate(hashTable)
-    }
+    binaryMap.getTotalMemoryConsumption
   }
 
   override def get(key: InternalRow): Seq[InternalRow] = {
     val unsafeKey = key.asInstanceOf[UnsafeRow]
-
-    if (binaryMap != null) {
-      // Used in Broadcast join
-      val map = binaryMap  // avoid the compiler error
-      val loc = new map.Location  // this could be allocated in stack
-      binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
-        unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
-      if (loc.isDefined) {
-        val buffer = CompactBuffer[UnsafeRow]()
-
-        val base = loc.getValueBase
-        var offset = loc.getValueOffset
-        val last = offset + loc.getValueLength
-        while (offset < last) {
-          val numFields = Platform.getInt(base, offset)
-          val sizeInBytes = Platform.getInt(base, offset + 4)
-          offset += 8
-
-          val row = new UnsafeRow(numFields)
-          row.pointTo(base, offset, sizeInBytes)
-          buffer += row
-          offset += sizeInBytes
-        }
-        buffer
-      } else {
-        null
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      val buffer = CompactBuffer[UnsafeRow]()
+      val row = new UnsafeRow(numFields)
+      row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+      buffer += row
+      while (loc.nextValue()) {
+        val row = new UnsafeRow(numFields)
+        row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+        buffer += row
       }
-
+      buffer
     } else {
-      // Use the Java HashMap in local mode or for non-broadcast joins (e.g. ShuffleHashJoin)
-      hashTable.get(unsafeKey)
+      null
     }
   }
 
-  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
-    if (binaryMap != null) {
-      // This could happen when a cached broadcast object need to be dumped into disk to free memory
-      out.writeInt(binaryMap.numElements())
-
-      var buffer = new Array[Byte](64)
-      def write(base: Object, offset: Long, length: Int): Unit = {
-        if (buffer.length < length) {
-          buffer = new Array[Byte](length)
-        }
-        Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
-        out.write(buffer, 0, length)
-      }
+  override def close(): Unit = {
+    binaryMap.free()
+  }
 
-      val iter = binaryMap.iterator()
-      while (iter.hasNext) {
-        val loc = iter.next()
-        // [key size] [values size] [key bytes] [values bytes]
-        out.writeInt(loc.getKeyLength)
-        out.writeInt(loc.getValueLength)
-        write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
-        write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+    out.writeInt(numFields)
+    // TODO: move these into BytesToBytesMap
+    out.writeInt(binaryMap.numKeys())
+    out.writeInt(binaryMap.numValues())
+
+    var buffer = new Array[Byte](64)
+    def write(base: Object, offset: Long, length: Int): Unit = {
+      if (buffer.length < length) {
+        buffer = new Array[Byte](length)
       }
+      Platform.copyMemory(base, offset, buffer, Platform.BYTE_ARRAY_OFFSET, length)
+      out.write(buffer, 0, length)
+    }
 
-    } else {
-      assert(hashTable != null)
-      out.writeInt(hashTable.size())
-
-      val iter = hashTable.entrySet().iterator()
-      while (iter.hasNext) {
-        val entry = iter.next()
-        val key = entry.getKey
-        val values = entry.getValue
-
-        // write all the values as single byte array
-        var totalSize = 0L
-        var i = 0
-        while (i < values.length) {
-          totalSize += values(i).getSizeInBytes + 4 + 4
-          i += 1
-        }
-        assert(totalSize < Integer.MAX_VALUE, "values are too big")
-
-        // [key size] [values size] [key bytes] [values bytes]
-        out.writeInt(key.getSizeInBytes)
-        out.writeInt(totalSize.toInt)
-        out.write(key.getBytes)
-        i = 0
-        while (i < values.length) {
-          // [num of fields] [num of bytes] [row bytes]
-          // write the integer in native order, so they can be read by UNSAFE.getInt()
-          if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
-            out.writeInt(values(i).numFields())
-            out.writeInt(values(i).getSizeInBytes)
-          } else {
-            out.writeInt(Integer.reverseBytes(values(i).numFields()))
-            out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
-          }
-          out.write(values(i).getBytes)
-          i += 1
-        }
-      }
+    val iter = binaryMap.iterator()
+    while (iter.hasNext) {
+      val loc = iter.next()
+      // [key size] [values size] [key bytes] [value bytes]
+      out.writeInt(loc.getKeyLength)
+      out.writeInt(loc.getValueLength)
+      write(loc.getKeyBase, loc.getKeyOffset, loc.getKeyLength)
+      write(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
     }
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+    numFields = in.readInt()
     val nKeys = in.readInt()
+    val nValues = in.readInt()
     // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
     // TODO(josh): This needs to be revisited before we merge this patch; making this change now
     // so that tests compile:
@@ -314,7 +239,7 @@ private[joins] final class UnsafeHashedRelation(
     var i = 0
     var keyBuffer = new Array[Byte](1024)
     var valuesBuffer = new Array[Byte](1024)
-    while (i < nKeys) {
+    while (i < nValues) {
       val keySize = in.readInt()
       val valuesSize = in.readInt()
       if (keySize > keyBuffer.length) {
@@ -326,13 +251,11 @@ private[joins] final class UnsafeHashedRelation(
       }
       in.readFully(valuesBuffer, 0, valuesSize)
 
-      // put it into binary map
       val loc = binaryMap.lookup(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize)
-      assert(!loc.isDefined, "Duplicated key found!")
-      val putSuceeded = loc.putNewKey(
-        keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
+      val putSuceeded = loc.append(keyBuffer, Platform.BYTE_ARRAY_OFFSET, keySize,
         valuesBuffer, Platform.BYTE_ARRAY_OFFSET, valuesSize)
       if (!putSuceeded) {
+        binaryMap.free()
         throw new IOException("Could not allocate memory to grow BytesToBytesMap")
       }
       i += 1
@@ -340,6 +263,29 @@ private[joins] final class UnsafeHashedRelation(
   }
 }
 
+/**
+ * A HashedRelation for UnsafeRow with unique keys.
+ */
+private[joins] final class UniqueUnsafeHashedRelation(
+    private var numFields: Int,
+    private var binaryMap: BytesToBytesMap)
+  extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation {
+  def getValue(key: InternalRow): InternalRow = {
+    val unsafeKey = key.asInstanceOf[UnsafeRow]
+    val map = binaryMap  // avoid the compiler error
+    val loc = new map.Location  // this could be allocated in stack
+    binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
+      unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode())
+    if (loc.isDefined) {
+      val row = new UnsafeRow(numFields)
+      row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength)
+      row
+    } else {
+      null
+    }
+  }
+}
+
 private[joins] object UnsafeHashedRelation {
 
   def apply(
@@ -347,29 +293,54 @@ private[joins] object UnsafeHashedRelation {
       keyGenerator: UnsafeProjection,
       sizeEstimate: Int): HashedRelation = {
 
-    // Use a Java hash table here because unsafe maps expect fixed size records
-    // TODO: Use BytesToBytesMap for memory efficiency
-    val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
+    val taskMemoryManager = if (TaskContext.get() != null) {
+      TaskContext.get().taskMemoryManager()
+    } else {
+      new TaskMemoryManager(
+        new StaticMemoryManager(
+          new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+          Long.MaxValue,
+          Long.MaxValue,
+          1),
+        0)
+    }
+    val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes)
+      .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m"))
+
+    val binaryMap = new BytesToBytesMap(
+      taskMemoryManager,
+      // Only 70% of the slots can be used before growing, more capacity help to reduce collision
+      (sizeEstimate * 1.5 + 1).toInt,
+      pageSizeBytes)
 
     // Create a mapping of buildKeys -> rows
+    var numFields = 0
+    // Whether all the keys are unique or not
+    var allUnique: Boolean = true
     while (input.hasNext) {
-      val unsafeRow = input.next().asInstanceOf[UnsafeRow]
-      val rowKey = keyGenerator(unsafeRow)
-      if (!rowKey.anyNull) {
-        val existingMatchList = hashTable.get(rowKey)
-        val matchList = if (existingMatchList == null) {
-          val newMatchList = new CompactBuffer[UnsafeRow]()
-          hashTable.put(rowKey.copy(), newMatchList)
-          newMatchList
-        } else {
-          existingMatchList
+      val row = input.next().asInstanceOf[UnsafeRow]
+      numFields = row.numFields()
+      val key = keyGenerator(row)
+      if (!key.anyNull) {
+        val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+        if (loc.isDefined) {
+          allUnique = false
+        }
+        val success = loc.append(
+          key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+          row.getBaseObject, row.getBaseOffset, row.getSizeInBytes)
+        if (!success) {
+          binaryMap.free()
+          throw new SparkException("There is not enough memory to build the hash map")
         }
-        matchList += unsafeRow
       }
     }
 
-    // TODO: create UniqueUnsafeRelation
-    new UnsafeHashedRelation(hashTable)
+    if (allUnique) {
+      new UniqueUnsafeHashedRelation(numFields, binaryMap)
+    } else {
+      new UnsafeHashedRelation(numFields, binaryMap)
+    }
   }
 }
 
@@ -523,7 +494,7 @@ private[joins] object LongHashedRelation {
     keyGenerator: Projection,
     sizeEstimate: Int): HashedRelation = {
 
-    // Use a Java hash table here because unsafe maps expect fixed size records
+    // TODO: use LongToBytesMap for better memory efficiency
     val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate)
 
     // Create a mapping of key -> rows
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 5c4f1ef60f..e3a2eaea5d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -57,9 +57,19 @@ case class ShuffledHashJoin(
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
   private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = {
+    val context = TaskContext.get()
+    if (!canJoinKeyFitWithinLong) {
+      // build BytesToBytesMap
+      val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator)
+      // This relation is usually used until the end of task.
+      context.addTaskCompletionListener((t: TaskContext) =>
+        relation.close()
+      )
+      return relation
+    }
+
     // try to acquire some memory for the hash table, it could trigger other operator to free some
     // memory. The memory acquired here will mostly be used until the end of task.
-    val context = TaskContext.get()
     val memoryManager = context.taskMemoryManager()
     var acquired = 0L
     var used = 0L
@@ -69,18 +79,18 @@ case class ShuffledHashJoin(
 
     val copiedIter = iter.map { row =>
       // It's hard to guess what's exactly memory will be used, we have a rough guess here.
-      // TODO: use BytesToBytesMap instead of HashMap for memory efficiency
-      // Each pair in HashMap will have two UnsafeRows, one CompactBuffer, maybe 10+ pointers
+      // TODO: use LongToBytesMap instead of HashMap for memory efficiency
+      // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers
       val needed = 150 + row.getSizeInBytes
       if (needed > acquired - used) {
         val got = memoryManager.acquireExecutionMemory(
           Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null)
+        acquired += got
         if (got < needed) {
           throw new SparkException("Can't acquire enough memory to build hash map in shuffled" +
             "hash join, please use sort merge join by setting " +
             "spark.sql.join.preferSortMergeJoin=true")
         }
-        acquired += got
       }
       used += needed
       // HashedRelation requires that the UnsafeRow should be separate objects.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 0b1cb90186..a16092e7d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -184,11 +184,29 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     /**
     Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-    Join w 2 longs:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    Join w 2 longs:                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
     -------------------------------------------------------------------------------------------
-    Join w 2 longs codegen=false              7877 / 8358         13.3          75.1       1.0X
-    Join w 2 longs codegen=true               3877 / 3937         27.0          37.0       2.0X
+    Join w 2 longs codegen=false           12725 / 13158          8.2         121.4       1.0X
+    Join w 2 longs codegen=true              6044 / 6771         17.3          57.6       2.1X
       */
+
+    val dim4 = broadcast(sqlContext.range(M)
+      .selectExpr("cast(id/10 as long) as k1", "cast(id/10 as long) as k2"))
+
+    runBenchmark("Join w 2 longs duplicated", N) {
+      sqlContext.range(N).join(dim4,
+        (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2"))
+        .count()
+    }
+
+    /**
+    Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+    Join w 2 longs:                     Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    Join w 2 longs duplicated codegen=false 13066 / 13710          8.0         124.6       1.0X
+    Join w 2 longs duplicated codegen=true    7122 / 7277         14.7          67.9       1.8X
+     */
+
     runBenchmark("outer join w long", N) {
       sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count()
     }
@@ -438,7 +456,7 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
             value.setInt(0, value.getInt(0) + 1)
             i += 1
           } else {
-            loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+            loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
               value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
           }
         }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index e19b4ff1e2..ed4cc1c4c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.joins
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.util.collection.CompactBuffer
 
 class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
@@ -69,10 +71,17 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
   }
 
   test("test serialization empty hash map") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val binaryMap = new BytesToBytesMap(taskMemoryManager, 1, 1)
     val os = new ByteArrayOutputStream()
     val out = new ObjectOutputStream(os)
-    val hashed = new UnsafeHashedRelation(
-      new java.util.HashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
+    val hashed = new UnsafeHashedRelation(1, binaryMap)
     hashed.writeExternal(out)
     out.flush()
     val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
-- 
cgit v1.2.3


From 7007f72ba7f0ccdc1d11315757a80f55a93451df Mon Sep 17 00:00:00 2001
From: Yin Huai 
Date: Mon, 28 Mar 2016 13:50:42 -0700
Subject: [SPARK-13713][SQL][TEST-MAVEN] Add Antlr4 maven plugin.

It seems https://github.com/apache/spark/commit/600c0b69cab4767e8e5a6f4284777d8b9d4bd40e is missing the antlr4 maven plugin. This PR adds it.

Author: Yin Huai 

Closes #12010 from yhuai/mavenAntlr4.
---
 pom.xml              |  5 +++++
 sql/catalyst/pom.xml | 15 +++++++++++++++
 2 files changed, 20 insertions(+)

diff --git a/pom.xml b/pom.xml
index 475f0544bd..1513a18b71 100644
--- a/pom.xml
+++ b/pom.xml
@@ -1891,6 +1891,11 @@
           antlr3-maven-plugin
           3.5.2
         
+        <plugin>
+          <groupId>org.antlr</groupId>
+          <artifactId>antlr4-maven-plugin</artifactId>
+          <version>${antlr4.version}</version>
+        </plugin>
         
         
           org.apache.maven.plugins
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index c834a011f1..9bfe495e90 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -133,6 +133,21 @@
           
         
       
+      <plugin>
+        <groupId>org.antlr</groupId>
+        <artifactId>antlr4-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <goals>
+              <goal>antlr4</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <visitor>true</visitor>
+          <sourceDirectory>../catalyst/src/main/antlr4</sourceDirectory>
+        </configuration>
+      </plugin>
     
   
 
-- 
cgit v1.2.3


From ff3bea38ed2ac8dac5832f0bf8eac70192a512ef Mon Sep 17 00:00:00 2001
From: nfraison 
Date: Mon, 28 Mar 2016 14:10:25 -0700
Subject: [SPARK-13622][YARN] Issue creating level db for YARN shuffle service

## What changes were proposed in this pull request?
This patch ensures that we trim all paths set in yarn.nodemanager.local-dirs and that the scheme is removed, so that the LevelDB can be created.
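
For illustration, a minimal Scala sketch of the scheme-stripping step (assuming hadoop-common is on the classpath; the directory value below is hypothetical, not taken from the patch):

```scala
import org.apache.hadoop.fs.Path

// A yarn.nodemanager.local-dirs entry may carry a URI scheme and stray
// whitespace, e.g. " file:///var/yarn/local ". LevelDB needs a plain
// filesystem path, so trim the entry and keep only the path component.
val rawDir = " file:///var/yarn/local "   // hypothetical config value
val localPath = new Path(rawDir.trim).toUri.getPath
println(localPath)                        // prints /var/yarn/local
```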

## How was this patch tested?
manual tests.

Author: nfraison 

Closes #11475 from ashangit/level_db_creation_issue.
---
 .../java/org/apache/spark/network/yarn/YarnShuffleService.java     | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
index ba6d30a74c..4bc3c1a3c8 100644
--- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
+++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java
@@ -24,6 +24,7 @@ import java.util.List;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.api.records.ContainerId;
 import org.apache.hadoop.yarn.server.api.*;
 import org.slf4j.Logger;
@@ -118,7 +119,7 @@ public class YarnShuffleService extends AuxiliaryService {
     // an application was stopped while the NM was down, we expect yarn to call stopApplication()
     // when it comes back
     registeredExecutorFile =
-      findRegisteredExecutorFile(conf.getStrings("yarn.nodemanager.local-dirs"));
+      findRegisteredExecutorFile(conf.getTrimmedStrings("yarn.nodemanager.local-dirs"));
 
     TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf));
     // If authentication is enabled, set up the shuffle server to use a
@@ -191,12 +192,12 @@ public class YarnShuffleService extends AuxiliaryService {
 
   private File findRegisteredExecutorFile(String[] localDirs) {
     for (String dir: localDirs) {
-      File f = new File(dir, "registeredExecutors.ldb");
+      File f = new File(new Path(dir).toUri().getPath(), "registeredExecutors.ldb");
       if (f.exists()) {
         return f;
       }
     }
-    return new File(localDirs[0], "registeredExecutors.ldb");
+    return new File(new Path(localDirs[0]).toUri().getPath(), "registeredExecutors.ldb");
   }
 
   /**
-- 
cgit v1.2.3


From 39f743a6231cbd8cc770a28f43ee601eff28d597 Mon Sep 17 00:00:00 2001
From: zero323 
Date: Mon, 28 Mar 2016 14:51:36 -0700
Subject: [SPARK-14202] [PYTHON] Use generator expression instead of list comp
 in python_full_outer_jo…
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?

This PR replaces the list comprehension in python_full_outer_join.dispatch with a generator expression.
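
For intuition, here is the same eager-vs-lazy distinction expressed in Scala (an illustrative analogy only, not part of the patch):

```scala
val vbuf = Seq(1, 2, 3)
val wbuf = Seq("a", "b")

// Eager: materializes every (v, w) pair up front, like a list comprehension.
val eager: Seq[(Int, String)] = for (v <- vbuf; w <- wbuf) yield (v, w)

// Lazy: pairs are produced on demand, like a generator expression.
val lazyPairs: Iterator[(Int, String)] =
  for (v <- vbuf.iterator; w <- wbuf.iterator) yield (v, w)

println(eager.size)               // 6, already built
println(lazyPairs.take(2).toList) // only the first two pairs are computed
```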

## How was this patch tested?

PySpark-Core, PySpark-MLlib test suites against Python 2.7, 3.5.

Author: zero323 

Closes #11998 from zero323/pyspark-join-generator-expr.
---
 python/pyspark/join.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index 94df399016..c1f5362648 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -93,7 +93,7 @@ def python_full_outer_join(rdd, other, numPartitions):
             vbuf.append(None)
         if not wbuf:
             wbuf.append(None)
-        return [(v, w) for v in vbuf for w in wbuf]
+        return ((v, w) for v in vbuf for w in wbuf)
     return _do_python_join(rdd, other, numPartitions, dispatch)
 
 
-- 
cgit v1.2.3


From 8c11d1aab8522c75d78bc6b30402c64e8d9ff065 Mon Sep 17 00:00:00 2001
From: Xusen Yin 
Date: Mon, 28 Mar 2016 15:40:06 -0700
Subject: [SPARK-11893] Model export/import for spark.ml: TrainValidationSplit

https://issues.apache.org/jira/browse/SPARK-11893

jkbradley In order to share read/write with `TrainValidationSplit`, I move `SharedReadWrite` out of `CrossValidator` into a new trait `SharedReadWrite` in the tuning package.

To reduce test duplication, I move the complex tests from `CrossValidatorSuite` to `SharedReadWriteSuite`, and create a fake validator called `MyValidator` to test the shared code.

With `SharedReadWrite`, any newly added `Validator` can share the common read/write logic and only needs to implement save/load for its extra params.
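
As a usage sketch of the persistence support added here (a minimal sketch, assuming a SparkContext/SQLContext is already set up; the save path is made up):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LogisticRegression().setMaxIter(3)
val paramMaps = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.2))
  .build()

val tvs = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramMaps)
  .setTrainRatio(0.75)

// write/read are what this patch adds, backed by ValidatorParams.saveImpl/loadImpl.
tvs.write.overwrite().save("/tmp/tvs")              // hypothetical path
val loaded = TrainValidationSplit.load("/tmp/tvs")
assert(loaded.getTrainRatio == tvs.getTrainRatio)
```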

Author: Xusen Yin 
Author: Joseph K. Bradley 

Closes #9971 from yinxusen/SPARK-11893.
---
 .../apache/spark/ml/tuning/CrossValidator.scala    | 148 ++-------------------
 .../spark/ml/tuning/TrainValidationSplit.scala     | 100 +++++++++++++-
 .../apache/spark/ml/tuning/ValidatorParams.scala   | 117 +++++++++++++++-
 .../scala/org/apache/spark/ml/util/ReadWrite.scala |  42 +++++-
 .../ml/tuning/TrainValidationSplitSuite.scala      |  45 ++++++-
 5 files changed, 310 insertions(+), 142 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 963f81cb3e..040b0093b9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -19,25 +19,19 @@ package org.apache.spark.ml.tuning
 
 import com.github.fommil.netlib.F2jBLAS
 import org.apache.hadoop.fs.Path
-import org.json4s.{DefaultFormats, JObject}
-import org.json4s.jackson.JsonMethods._
+import org.json4s.DefaultFormats
 
-import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
-import org.apache.spark.ml.classification.OneVsRestParams
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
-
 /**
  * Params for [[CrossValidator]] and [[CrossValidatorModel]].
  */
@@ -45,6 +39,7 @@ private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
   /**
    * Param for number of folds for cross validation.  Must be >= 2.
    * Default: 3
+   *
    * @group param
    */
   val numFolds: IntParam = new IntParam(this, "numFolds",
@@ -163,10 +158,10 @@ object CrossValidator extends MLReadable[CrossValidator] {
 
   private[CrossValidator] class CrossValidatorWriter(instance: CrossValidator) extends MLWriter {
 
-    SharedReadWrite.validateParams(instance)
+    ValidatorParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit =
-      SharedReadWrite.saveImpl(path, instance, sc)
+      ValidatorParams.saveImpl(path, instance, sc)
   }
 
   private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -175,8 +170,11 @@ object CrossValidator extends MLReadable[CrossValidator] {
     private val className = classOf[CrossValidator].getName
 
     override def load(path: String): CrossValidator = {
-      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
-        SharedReadWrite.load(path, sc, className)
+      implicit val format = DefaultFormats
+
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val numFolds = (metadata.params \ "numFolds").extract[Int]
       new CrossValidator(metadata.uid)
         .setEstimator(estimator)
         .setEvaluator(evaluator)
@@ -184,123 +182,6 @@ object CrossValidator extends MLReadable[CrossValidator] {
         .setNumFolds(numFolds)
     }
   }
-
-  private object CrossValidatorReader {
-    /**
-     * Examine the given estimator (which may be a compound estimator) and extract a mapping
-     * from UIDs to corresponding [[Params]] instances.
-     */
-    def getUidMap(instance: Params): Map[String, Params] = {
-      val uidList = getUidMapImpl(instance)
-      val uidMap = uidList.toMap
-      if (uidList.size != uidMap.size) {
-        throw new RuntimeException("CrossValidator.load found a compound estimator with stages" +
-          s" with duplicate UIDs.  List of UIDs: ${uidList.map(_._1).mkString(", ")}")
-      }
-      uidMap
-    }
-
-    def getUidMapImpl(instance: Params): List[(String, Params)] = {
-      val subStages: Array[Params] = instance match {
-        case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
-        case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
-        case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
-        case ovr: OneVsRestParams =>
-          // TODO: SPARK-11892: This case may require special handling.
-          throw new UnsupportedOperationException("CrossValidator write will fail because it" +
-            " cannot yet handle an estimator containing type: ${ovr.getClass.getName}")
-        case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
-        case _: Params => Array()
-      }
-      val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
-      List((instance.uid, instance)) ++ subStageMaps
-    }
-  }
-
-  private[tuning] object SharedReadWrite {
-
-    /**
-     * Check that [[CrossValidator.evaluator]] and [[CrossValidator.estimator]] are Writable.
-     * This does not check [[CrossValidator.estimatorParamMaps]].
-     */
-    def validateParams(instance: ValidatorParams): Unit = {
-      def checkElement(elem: Params, name: String): Unit = elem match {
-        case stage: MLWritable => // good
-        case other =>
-          throw new UnsupportedOperationException("CrossValidator write will fail " +
-            s" because it contains $name which does not implement Writable." +
-            s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
-      }
-      checkElement(instance.getEvaluator, "evaluator")
-      checkElement(instance.getEstimator, "estimator")
-      // Check to make sure all Params apply to this estimator.  Throw an error if any do not.
-      // Extraneous Params would cause problems when loading the estimatorParamMaps.
-      val uidToInstance: Map[String, Params] = CrossValidatorReader.getUidMap(instance)
-      instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
-        pMap.toSeq.foreach { case ParamPair(p, v) =>
-          require(uidToInstance.contains(p.parent), s"CrossValidator save requires all Params in" +
-            s" estimatorParamMaps to apply to this CrossValidator, its Estimator, or its" +
-            s" Evaluator.  An extraneous Param was found: $p")
-        }
-      }
-    }
-
-    private[tuning] def saveImpl(
-        path: String,
-        instance: CrossValidatorParams,
-        sc: SparkContext,
-        extraMetadata: Option[JObject] = None): Unit = {
-      import org.json4s.JsonDSL._
-
-      val estimatorParamMapsJson = compact(render(
-        instance.getEstimatorParamMaps.map { case paramMap =>
-          paramMap.toSeq.map { case ParamPair(p, v) =>
-            Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
-          }
-        }.toSeq
-      ))
-      val jsonParams = List(
-        "numFolds" -> parse(instance.numFolds.jsonEncode(instance.getNumFolds)),
-        "estimatorParamMaps" -> parse(estimatorParamMapsJson)
-      )
-      DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
-
-      val evaluatorPath = new Path(path, "evaluator").toString
-      instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
-      val estimatorPath = new Path(path, "estimator").toString
-      instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
-    }
-
-    private[tuning] def load[M <: Model[M]](
-        path: String,
-        sc: SparkContext,
-        expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap], Int) = {
-
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
-
-      implicit val format = DefaultFormats
-      val evaluatorPath = new Path(path, "evaluator").toString
-      val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
-      val estimatorPath = new Path(path, "estimator").toString
-      val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
-
-      val uidToParams = Map(evaluator.uid -> evaluator) ++ CrossValidatorReader.getUidMap(estimator)
-
-      val numFolds = (metadata.params \ "numFolds").extract[Int]
-      val estimatorParamMaps: Array[ParamMap] =
-        (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
-          pMap =>
-            val paramPairs = pMap.map { case pInfo: Map[String, String] =>
-              val est = uidToParams(pInfo("parent"))
-              val param = est.getParam(pInfo("name"))
-              val value = param.jsonDecode(pInfo("value"))
-              param -> value
-            }
-            ParamMap(paramPairs: _*)
-        }.toArray
-      (metadata, estimator, evaluator, estimatorParamMaps, numFolds)
-    }
-  }
 }
 
 /**
@@ -346,8 +227,6 @@ class CrossValidatorModel private[ml] (
 @Since("1.6.0")
 object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
 
-  import CrossValidator.SharedReadWrite
-
   @Since("1.6.0")
   override def read: MLReader[CrossValidatorModel] = new CrossValidatorModelReader
 
@@ -357,12 +236,12 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
   private[CrossValidatorModel]
   class CrossValidatorModelWriter(instance: CrossValidatorModel) extends MLWriter {
 
-    SharedReadWrite.validateParams(instance)
+    ValidatorParams.validateParams(instance)
 
     override protected def saveImpl(path: String): Unit = {
       import org.json4s.JsonDSL._
       val extraMetadata = "avgMetrics" -> instance.avgMetrics.toSeq
-      SharedReadWrite.saveImpl(path, instance, sc, Some(extraMetadata))
+      ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
       val bestModelPath = new Path(path, "bestModel").toString
       instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
     }
@@ -376,8 +255,9 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
     override def load(path: String): CrossValidatorModel = {
       implicit val format = DefaultFormats
 
-      val (metadata, estimator, evaluator, estimatorParamMaps, numFolds) =
-        SharedReadWrite.load(path, sc, className)
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val numFolds = (metadata.params \ "numFolds").extract[Int]
       val bestModelPath = new Path(path, "bestModel").toString
       val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
       val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 70fa5f0234..4d1d6364d7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.ml.tuning
 
+import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
@@ -33,6 +36,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
   /**
    * Param for ratio between train and validation data. Must be between 0 and 1.
    * Default: 0.75
+   *
    * @group param
    */
   val trainRatio: DoubleParam = new DoubleParam(this, "trainRatio",
@@ -55,7 +59,7 @@ private[ml] trait TrainValidationSplitParams extends ValidatorParams {
 @Experimental
 class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: String)
   extends Estimator[TrainValidationSplitModel]
-  with TrainValidationSplitParams with Logging {
+  with TrainValidationSplitParams with MLWritable with Logging {
 
   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("tvs"))
@@ -130,6 +134,47 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
     }
     copied
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new TrainValidationSplit.TrainValidationSplitWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
+
+  @Since("2.0.0")
+  override def read: MLReader[TrainValidationSplit] = new TrainValidationSplitReader
+
+  @Since("2.0.0")
+  override def load(path: String): TrainValidationSplit = super.load(path)
+
+  private[TrainValidationSplit] class TrainValidationSplitWriter(instance: TrainValidationSplit)
+    extends MLWriter {
+
+    ValidatorParams.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit =
+      ValidatorParams.saveImpl(path, instance, sc)
+  }
+
+  private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[TrainValidationSplit].getName
+
+    override def load(path: String): TrainValidationSplit = {
+      implicit val format = DefaultFormats
+
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+      new TrainValidationSplit(metadata.uid)
+        .setEstimator(estimator)
+        .setEvaluator(evaluator)
+        .setEstimatorParamMaps(estimatorParamMaps)
+        .setTrainRatio(trainRatio)
+    }
+  }
 }
 
 /**
@@ -146,7 +191,7 @@ class TrainValidationSplitModel private[ml] (
     @Since("1.5.0") override val uid: String,
     @Since("1.5.0") val bestModel: Model[_],
     @Since("1.5.0") val validationMetrics: Array[Double])
-  extends Model[TrainValidationSplitModel] with TrainValidationSplitParams {
+  extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable {
 
   @Since("1.5.0")
   override def transform(dataset: DataFrame): DataFrame = {
@@ -167,4 +212,53 @@ class TrainValidationSplitModel private[ml] (
       validationMetrics.clone())
     copyValues(copied, extra)
   }
+
+  @Since("2.0.0")
+  override def write: MLWriter = new TrainValidationSplitModel.TrainValidationSplitModelWriter(this)
+}
+
+@Since("2.0.0")
+object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
+
+  @Since("2.0.0")
+  override def read: MLReader[TrainValidationSplitModel] = new TrainValidationSplitModelReader
+
+  @Since("2.0.0")
+  override def load(path: String): TrainValidationSplitModel = super.load(path)
+
+  private[TrainValidationSplitModel]
+  class TrainValidationSplitModelWriter(instance: TrainValidationSplitModel) extends MLWriter {
+
+    ValidatorParams.validateParams(instance)
+
+    override protected def saveImpl(path: String): Unit = {
+      import org.json4s.JsonDSL._
+      val extraMetadata = "validationMetrics" -> instance.validationMetrics.toSeq
+      ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
+      val bestModelPath = new Path(path, "bestModel").toString
+      instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
+    }
+  }
+
+  private class TrainValidationSplitModelReader extends MLReader[TrainValidationSplitModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[TrainValidationSplitModel].getName
+
+    override def load(path: String): TrainValidationSplitModel = {
+      implicit val format = DefaultFormats
+
+      val (metadata, estimator, evaluator, estimatorParamMaps) =
+        ValidatorParams.loadImpl(path, sc, className)
+      val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+      val bestModelPath = new Path(path, "bestModel").toString
+      val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
+      val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
+      val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
+      tvs.set(tvs.estimator, estimator)
+        .set(tvs.evaluator, evaluator)
+        .set(tvs.estimatorParamMaps, estimatorParamMaps)
+        .set(tvs.trainRatio, trainRatio)
+    }
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 953456e8f0..7a4e106aeb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -17,9 +17,17 @@
 
 package org.apache.spark.ml.tuning
 
-import org.apache.spark.ml.Estimator
+import org.apache.hadoop.fs.Path
+import org.json4s.{DefaultFormats, _}
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.evaluation.Evaluator
-import org.apache.spark.ml.param.{Param, ParamMap, Params}
+import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite,
+  MLWritable}
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -69,3 +77,108 @@ private[ml] trait ValidatorParams extends Params {
     est.copy(firstEstimatorParamMap).transformSchema(schema)
   }
 }
+
+private[ml] object ValidatorParams {
+  /**
+   * Check that [[ValidatorParams.evaluator]] and [[ValidatorParams.estimator]] are Writable.
+   * This does not check [[ValidatorParams.estimatorParamMaps]].
+   */
+  def validateParams(instance: ValidatorParams): Unit = {
+    def checkElement(elem: Params, name: String): Unit = elem match {
+      case stage: MLWritable => // good
+      case other =>
+        throw new UnsupportedOperationException(instance.getClass.getName + " write will fail " +
+          s" because it contains $name which does not implement Writable." +
+          s" Non-Writable $name: ${other.uid} of type ${other.getClass}")
+    }
+    checkElement(instance.getEvaluator, "evaluator")
+    checkElement(instance.getEstimator, "estimator")
+    // Check to make sure all Params apply to this estimator.  Throw an error if any do not.
+    // Extraneous Params would cause problems when loading the estimatorParamMaps.
+    val uidToInstance: Map[String, Params] = MetaAlgorithmReadWrite.getUidMap(instance)
+    instance.getEstimatorParamMaps.foreach { case pMap: ParamMap =>
+      pMap.toSeq.foreach { case ParamPair(p, v) =>
+        require(uidToInstance.contains(p.parent), s"ValidatorParams save requires all Params in" +
+          s" estimatorParamMaps to apply to this ValidatorParams, its Estimator, or its" +
+          s" Evaluator. An extraneous Param was found: $p")
+      }
+    }
+  }
+
+  /**
+   * Generic implementation of save for [[ValidatorParams]] types.
+   * This handles all [[ValidatorParams]] fields and saves [[Param]] values, but the implementing
+   * class needs to handle model data.
+   */
+  def saveImpl(
+      path: String,
+      instance: ValidatorParams,
+      sc: SparkContext,
+      extraMetadata: Option[JObject] = None): Unit = {
+    import org.json4s.JsonDSL._
+
+    val estimatorParamMapsJson = compact(render(
+      instance.getEstimatorParamMaps.map { case paramMap =>
+        paramMap.toSeq.map { case ParamPair(p, v) =>
+          Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v))
+        }
+      }.toSeq
+    ))
+
+    val validatorSpecificParams = instance match {
+      case cv: CrossValidatorParams =>
+        List("numFolds" -> parse(cv.numFolds.jsonEncode(cv.getNumFolds)))
+      case tvs: TrainValidationSplitParams =>
+        List("trainRatio" -> parse(tvs.trainRatio.jsonEncode(tvs.getTrainRatio)))
+      case _ =>
+        // This should not happen.
+        throw new NotImplementedError("ValidatorParams.saveImpl does not handle type: " +
+          instance.getClass.getCanonicalName)
+    }
+
+    val jsonParams = validatorSpecificParams ++ List(
+      "estimatorParamMaps" -> parse(estimatorParamMapsJson))
+
+    DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
+
+    val evaluatorPath = new Path(path, "evaluator").toString
+    instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
+    val estimatorPath = new Path(path, "estimator").toString
+    instance.getEstimator.asInstanceOf[MLWritable].save(estimatorPath)
+  }
+
+  /**
+   * Generic implementation of load for [[ValidatorParams]] types.
+   * This handles all [[ValidatorParams]] fields, but the implementing
+   * class needs to handle model data and special [[Param]] values.
+   */
+  def loadImpl[M <: Model[M]](
+      path: String,
+      sc: SparkContext,
+      expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {
+
+    val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
+
+    implicit val format = DefaultFormats
+    val evaluatorPath = new Path(path, "evaluator").toString
+    val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
+    val estimatorPath = new Path(path, "estimator").toString
+    val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
+
+    val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)
+
+    val estimatorParamMaps: Array[ParamMap] =
+      (metadata.params \ "estimatorParamMaps").extract[Seq[Seq[Map[String, String]]]].map {
+        pMap =>
+          val paramPairs = pMap.map { case pInfo: Map[String, String] =>
+            val est = uidToParams(pInfo("parent"))
+            val param = est.getParam(pInfo("name"))
+            val value = param.jsonDecode(pInfo("value"))
+            param -> value
+          }
+          ParamMap(paramPairs: _*)
+      }.toArray
+
+    (metadata, estimator, evaluator, estimatorParamMaps)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index c95e536abd..5a596cad06 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -21,13 +21,18 @@ import java.io.IOException
 
 import org.apache.hadoop.fs.Path
 import org.json4s._
-import org.json4s.jackson.JsonMethods._
+import org.json4s.{DefaultFormats, JObject}
 import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml._
+import org.apache.spark.ml.classification.OneVsRestParams
+import org.apache.spark.ml.feature.RFormulaModel
 import org.apache.spark.ml.param.{ParamPair, Params}
+import org.apache.spark.ml.tuning.ValidatorParams
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.util.Utils
 
@@ -352,3 +357,38 @@ private[ml] object DefaultParamsReader {
     cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
   }
 }
+
+/**
+ * Default Meta-Algorithm read and write implementation.
+ */
+private[ml] object MetaAlgorithmReadWrite {
+  /**
+   * Examine the given estimator (which may be a compound estimator) and extract a mapping
+   * from UIDs to corresponding [[Params]] instances.
+   */
+  def getUidMap(instance: Params): Map[String, Params] = {
+    val uidList = getUidMapImpl(instance)
+    val uidMap = uidList.toMap
+    if (uidList.size != uidMap.size) {
+      throw new RuntimeException(s"${instance.getClass.getName}.load found a compound estimator" +
+        s" with stages with duplicate UIDs. List of UIDs: ${uidList.map(_._1).mkString(", ")}.")
+    }
+    uidMap
+  }
+
+  private def getUidMapImpl(instance: Params): List[(String, Params)] = {
+    val subStages: Array[Params] = instance match {
+      case p: Pipeline => p.getStages.asInstanceOf[Array[Params]]
+      case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]]
+      case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator)
+      case ovr: OneVsRestParams =>
+        // TODO: SPARK-11892: This case may require special handling.
+        throw new UnsupportedOperationException(s"${instance.getClass.getName} write will fail" +
+          s" because it cannot yet handle an estimator containing type: ${ovr.getClass.getName}.")
+      case rformModel: RFormulaModel => Array(rformModel.pipelineModel)
+      case _: Params => Array()
+    }
+    val subStageMaps = subStages.map(getUidMapImpl).foldLeft(List.empty[(String, Params)])(_ ++ _)
+    List((instance.uid, instance)) ++ subStageMaps
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index cf8dcefebc..7cf7b3e087 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -19,17 +19,20 @@ package org.apache.spark.ml.tuning
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.classification.LogisticRegression
+import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
 import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.regression.LinearRegression
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
 
-class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext {
+class TrainValidationSplitSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
   test("train validation with logistic regression") {
     val dataset = sqlContext.createDataFrame(
       sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
@@ -105,6 +108,44 @@ class TrainValidationSplitSuite extends SparkFunSuite with MLlibTestSparkContext
       cv.transformSchema(new StructType())
     }
   }
+
+  test("read/write: TrainValidationSplit") {
+    val lr = new LogisticRegression().setMaxIter(3)
+    val evaluator = new BinaryClassificationEvaluator()
+    val paramMaps = new ParamGridBuilder()
+        .addGrid(lr.regParam, Array(0.1, 0.2))
+        .build()
+    val tvs = new TrainValidationSplit()
+      .setEstimator(lr)
+      .setEvaluator(evaluator)
+      .setTrainRatio(0.5)
+      .setEstimatorParamMaps(paramMaps)
+
+    val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+    assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+  }
+
+  test("read/write: TrainValidationSplitModel") {
+    val lr = new LogisticRegression()
+      .setThreshold(0.6)
+    val lrModel = new LogisticRegressionModel(lr.uid, Vectors.dense(1.0, 2.0), 1.2)
+      .setThreshold(0.6)
+    val evaluator = new BinaryClassificationEvaluator()
+    val paramMaps = new ParamGridBuilder()
+        .addGrid(lr.regParam, Array(0.1, 0.2))
+        .build()
+    val tvs = new TrainValidationSplitModel("cvUid", lrModel, Array(0.3, 0.6))
+    tvs.set(tvs.estimator, lr)
+      .set(tvs.evaluator, evaluator)
+      .set(tvs.trainRatio, 0.5)
+      .set(tvs.estimatorParamMaps, paramMaps)
+
+    val tvs2 = testDefaultReadWrite(tvs, testParams = false)
+
+    assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+    assert(tvs.validationMetrics === tvs2.validationMetrics)
+  }
 }
 
 object TrainValidationSplitSuite {
-- 
cgit v1.2.3


From 328c71161bdae569a534dcd05e14ec485e895c5c Mon Sep 17 00:00:00 2001
From: Herman van Hovell 
Date: Mon, 28 Mar 2016 16:22:02 -0700
Subject: [SPARK-14086][SQL] Add DDL commands to ANTLR4 parser

#### What changes were proposed in this pull request?

This PR adds all the current Spark SQL DDL commands to the new ANTLR 4 based SQL parser.

I have found a few inconsistencies in the current commands:
- Function has an alias field. This is actually the class name of the function.
- Partition specifications should contain nulls in some commands, and contain `None`s in others.
- `AlterTableSkewedLocation`: Should define which columns have skewed values, and should allow us to define a storage location for each skewed combination of values. We currently only allow one value per field.
- `AlterTableSetFileFormat`: Should only take one file format; it currently accepts both a generic format name and a full Hive table file format.

I have implemented these commands as they currently behave, and I propose to improve them in follow-up PRs.

#### How was this patch tested?

The existing DDLCommandSuite.
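
For illustration, a round-trip check in the style of that suite might look like the sketch below. This is only a sketch: `parsePlan` is the entry point inherited from `AbstractSqlParser`, `comparePlans` comes from `PlanTest`, and the `DropFunction` arguments follow the `visitDropFunction` code in this patch (parameter names as documented in the DropFunction scaladoc added earlier in this series).

    // Sketch only, in the spirit of DDLCommandSuite: parse one of the new DDL
    // statements with the ANTLR4-based parser and compare it to the expected command.
    val sql = "DROP TEMPORARY FUNCTION IF EXISTS helloworld"
    val parsed = SparkSqlParser.parsePlan(sql)
    val expected = DropFunction(
      None,           // no database qualifier
      "helloworld",   // function name
      ifExists = true,
      isTemp = true)(sql)
    comparePlans(parsed, expected)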

cc rxin andrewor14 yhuai

Author: Herman van Hovell 

Closes #12011 from hvanhovell/SPARK-14086.
---
 .../spark/sql/execution/SparkSqlParser.scala       | 620 ++++++++++++++++++++-
 .../sql/execution/command/DDLCommandSuite.scala    |   5 +-
 2 files changed, 619 insertions(+), 6 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index c098fa99c2..a8313deeef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -20,7 +20,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder}
+import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder, ParseException}
 import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
 import org.apache.spark.sql.execution.command.{DescribeCommand => _, _}
@@ -200,8 +200,8 @@ class SparkSqlAstBuilder extends AstBuilder {
   }
 
   /**
-    * Convert a table property list into a key-value map.
-    */
+   * Convert a table property list into a key-value map.
+   */
   override def visitTablePropertyList(
       ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) {
     ctx.tableProperty.asScala.map { property =>
@@ -216,4 +216,618 @@ class SparkSqlAstBuilder extends AstBuilder {
       key -> value
     }.toMap
   }
+
+  /**
+   * Create a [[CreateDatabase]] command.
+   *
+   * For example:
+   * {{{
+   *   CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment]
+   *    [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)]
+   * }}}
+   */
+  override def visitCreateDatabase(ctx: CreateDatabaseContext): LogicalPlan = withOrigin(ctx) {
+    CreateDatabase(
+      ctx.identifier.getText,
+      ctx.EXISTS != null,
+      Option(ctx.locationSpec).map(visitLocationSpec),
+      Option(ctx.comment).map(string),
+      Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterDatabaseProperties]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER (DATABASE|SCHEMA) database SET DBPROPERTIES (property_name=property_value, ...);
+   * }}}
+   */
+  override def visitSetDatabaseProperties(
+      ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) {
+    AlterDatabaseProperties(
+      ctx.identifier.getText,
+      visitTablePropertyList(ctx.tablePropertyList))(
+      command(ctx))
+  }
+
+  /**
+   * Create a [[DropDatabase]] command.
+   *
+   * For example:
+   * {{{
+   *   DROP (DATABASE|SCHEMA) [IF EXISTS] database [RESTRICT|CASCADE];
+   * }}}
+   */
+  override def visitDropDatabase(ctx: DropDatabaseContext): LogicalPlan = withOrigin(ctx) {
+    DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE == null)(command(ctx))
+  }
+
+  /**
+   * Create a [[DescribeDatabase]] command.
+   *
+   * For example:
+   * {{{
+   *   DESCRIBE DATABASE [EXTENDED] database;
+   * }}}
+   */
+  override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) {
+    DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null)(command(ctx))
+  }
+
+  /**
+   * Create a [[CreateFunction]] command.
+   *
+   * For example:
+   * {{{
+   *   CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name
+   *    [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri']];
+   * }}}
+   */
+  override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) {
+    val resources = ctx.resource.asScala.map { resource =>
+      val resourceType = resource.identifier.getText.toLowerCase
+      resourceType match {
+        case "jar" | "file" | "archive" =>
+          resourceType -> string(resource.STRING)
+        case other =>
+          throw new ParseException(s"Resource Type '$resourceType' is not supported.", ctx)
+      }
+    }
+
+    // Extract database, name & alias.
+    val (database, function) = visitFunctionName(ctx.qualifiedName)
+    CreateFunction(
+      database,
+      function,
+      string(ctx.className), // TODO this is not an alias.
+      resources,
+      ctx.TEMPORARY != null)(
+      command(ctx))
+  }
+
+  /**
+   * Create a [[DropFunction]] command.
+   *
+   * For example:
+   * {{{
+   *   DROP [TEMPORARY] FUNCTION [IF EXISTS] function;
+   * }}}
+   */
+  override def visitDropFunction(ctx: DropFunctionContext): LogicalPlan = withOrigin(ctx) {
+    val (database, function) = visitFunctionName(ctx.qualifiedName)
+    DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null)(command(ctx))
+  }
+
+  /**
+   * Create a (database, function name) pair; the database part is optional.
+   */
+  private def visitFunctionName(ctx: QualifiedNameContext): (Option[String], String) = {
+    ctx.identifier().asScala.map(_.getText) match {
+      case Seq(db, fn) => (Option(db), fn)
+      case Seq(fn) => (None, fn)
+      case other => throw new ParseException(s"Unsupported function name '${ctx.getText}'", ctx)
+    }
+  }
+
+  /**
+   * Create an [[AlterTableRename]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table1 RENAME TO table2;
+   * }}}
+   */
+  override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableRename(
+      visitTableIdentifier(ctx.from),
+      visitTableIdentifier(ctx.to))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableSetProperties]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment);
+   * }}}
+   */
+  override def visitSetTableProperties(
+      ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableSetProperties(
+      visitTableIdentifier(ctx.tableIdentifier),
+      visitTablePropertyList(ctx.tablePropertyList))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableUnsetProperties]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table UNSET TBLPROPERTIES IF EXISTS ('comment', 'key');
+   * }}}
+   */
+  override def visitUnsetTableProperties(
+      ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableUnsetProperties(
+      visitTableIdentifier(ctx.tableIdentifier),
+      visitTablePropertyList(ctx.tablePropertyList),
+      ctx.EXISTS != null)(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableSerDeProperties]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props];
+   *   ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties;
+   * }}}
+   */
+  override def visitSetTableSerDe(ctx: SetTableSerDeContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableSerDeProperties(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.STRING).map(string),
+      Option(ctx.tablePropertyList).map(visitTablePropertyList),
+      // TODO a partition spec is allowed to have optional values. This is currently violated.
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableStorageProperties]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table CLUSTERED BY (col, ...) [SORTED BY (col, ...)] INTO n BUCKETS;
+   * }}}
+   */
+  override def visitBucketTable(ctx: BucketTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableStorageProperties(
+      visitTableIdentifier(ctx.tableIdentifier),
+      visitBucketSpec(ctx.bucketSpec))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableNotClustered]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table NOT CLUSTERED;
+   * }}}
+   */
+  override def visitUnclusterTable(ctx: UnclusterTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableNotClustered(visitTableIdentifier(ctx.tableIdentifier))(command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableNotSorted]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table NOT SORTED;
+   * }}}
+   */
+  override def visitUnsortTable(ctx: UnsortTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableNotSorted(visitTableIdentifier(ctx.tableIdentifier))(command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableSkewed]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table SKEWED BY (col1, col2)
+   *   ON ((col1_value, col2_value) [, (col1_value, col2_value), ...])
+   *   [STORED AS DIRECTORIES];
+   * }}}
+   */
+  override def visitSkewTable(ctx: SkewTableContext): LogicalPlan = withOrigin(ctx) {
+    val table = visitTableIdentifier(ctx.tableIdentifier)
+    val (cols, values, storedAsDirs) = visitSkewSpec(ctx.skewSpec)
+    AlterTableSkewed(table, cols, values, storedAsDirs)(command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableNotSkewed]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table NOT SKEWED;
+   * }}}
+   */
+  override def visitUnskewTable(ctx: UnskewTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableNotSkewed(visitTableIdentifier(ctx.tableIdentifier))(command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableNotStoredAsDirs]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table NOT STORED AS DIRECTORIES
+   * }}}
+   */
+  override def visitUnstoreTable(ctx: UnstoreTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableNotStoredAsDirs(visitTableIdentifier(ctx.tableIdentifier))(command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableSkewedLocation]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table SET SKEWED LOCATION (col1="loc1" [, (col2, col3)="loc2", ...] );
+   * }}}
+   */
+  override def visitSetTableSkewLocations(
+      ctx: SetTableSkewLocationsContext): LogicalPlan = withOrigin(ctx) {
+    val skewedMap = ctx.skewedLocationList.skewedLocation.asScala.flatMap {
+      slCtx =>
+        val location = string(slCtx.STRING)
+        if (slCtx.constant != null) {
+          Seq(visitStringConstant(slCtx.constant) -> location)
+        } else {
+          // TODO this is similar to what was in the original implementation. However this does not
+          // make too much sense to me since we should be storing a tuple of values (not column
+          // names) for which we want a dedicated storage location.
+          visitConstantList(slCtx.constantList).map(_ -> location)
+        }
+    }.toMap
+
+    AlterTableSkewedLocation(
+      visitTableIdentifier(ctx.tableIdentifier),
+      skewedMap)(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableAddPartition]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1']
+   * }}}
+   */
+  override def visitAddTablePartition(
+      ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+    // Create partition spec to location mapping.
+    val specsAndLocs = ctx.partitionSpecLocation.asScala.map {
+      splCtx =>
+        val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec)
+        val location = Option(splCtx.locationSpec).map(visitLocationSpec)
+        spec -> location
+    }
+    AlterTableAddPartition(
+      visitTableIdentifier(ctx.tableIdentifier),
+      specsAndLocs,
+      ctx.EXISTS != null)(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableExchangePartition]] command.
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table1 EXCHANGE PARTITION spec WITH TABLE table2;
+   * }}}
+   */
+  override def visitExchangeTablePartition(
+      ctx: ExchangeTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableExchangePartition(
+      visitTableIdentifier(ctx.from),
+      visitTableIdentifier(ctx.to),
+      visitNonOptionalPartitionSpec(ctx.partitionSpec))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableRenamePartition]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2;
+   * }}}
+   */
+  override def visitRenameTablePartition(
+      ctx: RenameTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableRenamePartition(
+      visitTableIdentifier(ctx.tableIdentifier),
+      visitNonOptionalPartitionSpec(ctx.from),
+      visitNonOptionalPartitionSpec(ctx.to))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableDropPartition]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE];
+   * }}}
+   */
+  override def visitDropTablePartitions(
+      ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableDropPartition(
+      visitTableIdentifier(ctx.tableIdentifier),
+      ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec),
+      ctx.EXISTS != null,
+      ctx.PURGE != null)(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableArchivePartition]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table ARCHIVE PARTITION spec;
+   * }}}
+   */
+  override def visitArchiveTablePartition(
+      ctx: ArchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableArchivePartition(
+      visitTableIdentifier(ctx.tableIdentifier),
+      visitNonOptionalPartitionSpec(ctx.partitionSpec))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableUnarchivePartition]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table UNARCHIVE PARTITION spec;
+   * }}}
+   */
+  override def visitUnarchiveTablePartition(
+      ctx: UnarchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableUnarchivePartition(
+      visitTableIdentifier(ctx.tableIdentifier),
+      visitNonOptionalPartitionSpec(ctx.partitionSpec))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableSetFileFormat]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table [PARTITION spec] SET FILEFORMAT file_format;
+   * }}}
+   */
+  override def visitSetTableFileFormat(
+      ctx: SetTableFileFormatContext): LogicalPlan = withOrigin(ctx) {
+    // AlterTableSetFileFormat currently takes both a GenericFileFormat and a
+    // TableFileFormatContext. This is a bit weird because it should only take one. It also should
+    // use a CatalogFileFormat instead of either a String or a Sequence of Strings. We will address
+    // this in a follow-up PR.
+    val (fileFormat, genericFormat) = ctx.fileFormat match {
+      case s: GenericFileFormatContext =>
+        (Seq.empty[String], Option(s.identifier.getText))
+      case s: TableFileFormatContext =>
+        val elements = Seq(s.inFmt, s.outFmt) ++
+          Option(s.serdeCls).toSeq ++
+          Option(s.inDriver).toSeq ++
+          Option(s.outDriver).toSeq
+        (elements.map(string), None)
+    }
+    AlterTableSetFileFormat(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+      fileFormat,
+      genericFormat)(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableSetLocation]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table [PARTITION spec] SET LOCATION "loc";
+   * }}}
+   */
+  override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableSetLocation(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+      visitLocationSpec(ctx.locationSpec))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableTouch]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table TOUCH [PARTITION spec];
+   * }}}
+   */
+  override def visitTouchTable(ctx: TouchTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableTouch(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableCompact]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table [PARTITION spec] COMPACT 'compaction_type';
+   * }}}
+   */
+  override def visitCompactTable(ctx: CompactTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableCompact(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+      string(ctx.STRING))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableMerge]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE table [PARTITION spec] CONCATENATE;
+   * }}}
+   */
+  override def visitConcatenateTable(ctx: ConcatenateTableContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableMerge(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableChangeCol]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE tableIdentifier [PARTITION spec]
+   *    CHANGE [COLUMN] col_old_name col_new_name column_type [COMMENT col_comment]
+   *    [FIRST|AFTER column_name] [CASCADE|RESTRICT];
+   * }}}
+   */
+  override def visitChangeColumn(ctx: ChangeColumnContext): LogicalPlan = withOrigin(ctx) {
+    val col = visitColType(ctx.colType())
+    val comment = if (col.metadata.contains("comment")) {
+      Option(col.metadata.getString("comment"))
+    } else {
+      None
+    }
+
+    AlterTableChangeCol(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+      ctx.oldName.getText,
+      // We could also pass in a struct field - seems easier.
+      col.name,
+      col.dataType,
+      comment,
+      Option(ctx.after).map(_.getText),
+      // Note that Restrict and Cascade are mutually exclusive.
+      ctx.RESTRICT != null,
+      ctx.CASCADE != null)(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableAddCol]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE tableIdentifier [PARTITION spec]
+   *    ADD COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT]
+   * }}}
+   */
+  override def visitAddColumns(ctx: AddColumnsContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableAddCol(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+      createStructType(ctx.colTypeList),
+      // Note that Restrict and Cascade are mutually exclusive.
+      ctx.RESTRICT != null,
+      ctx.CASCADE != null)(
+      command(ctx))
+  }
+
+  /**
+   * Create an [[AlterTableReplaceCol]] command
+   *
+   * For example:
+   * {{{
+   *   ALTER TABLE tableIdentifier [PARTITION spec]
+   *    REPLACE COLUMNS (name type [COMMENT comment], ...) [CASCADE|RESTRICT]
+   * }}}
+   */
+  override def visitReplaceColumns(ctx: ReplaceColumnsContext): LogicalPlan = withOrigin(ctx) {
+    AlterTableReplaceCol(
+      visitTableIdentifier(ctx.tableIdentifier),
+      Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec),
+      createStructType(ctx.colTypeList),
+      // Note that Restrict and Cascade are mutually exclusive.
+      ctx.RESTRICT != null,
+      ctx.CASCADE != null)(
+      command(ctx))
+  }
+
+  /**
+   * Create location string.
+   */
+  override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) {
+    string(ctx.STRING)
+  }
+
+  /**
+   * Create a [[BucketSpec]].
+   */
+  override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) {
+    BucketSpec(
+      ctx.INTEGER_VALUE.getText.toInt,
+      visitIdentifierList(ctx.identifierList),
+      Option(ctx.orderedIdentifierList).toSeq
+        .flatMap(_.orderedIdentifier.asScala)
+        .map(_.identifier.getText))
+  }
+
+  /**
+   * Create a skew specification. This contains three components:
+   * - The skewed columns.
+   * - The values that are skewed. The size of each entry must match the number of skewed columns.
+   * - A "stored as directories" flag.
+   */
+  override def visitSkewSpec(
+      ctx: SkewSpecContext): (Seq[String], Seq[Seq[String]], Boolean) = withOrigin(ctx) {
+    val skewedValues = if (ctx.constantList != null) {
+      Seq(visitConstantList(ctx.constantList))
+    } else {
+      visitNestedConstantList(ctx.nestedConstantList)
+    }
+    (visitIdentifierList(ctx.identifierList), skewedValues, ctx.DIRECTORIES != null)
+  }
+
+  /**
+   * Convert a nested constants list into a sequence of string sequences.
+   */
+  override def visitNestedConstantList(
+      ctx: NestedConstantListContext): Seq[Seq[String]] = withOrigin(ctx) {
+    ctx.constantList.asScala.map(visitConstantList)
+  }
+
+  /**
+   * Convert a constants list into a String sequence.
+   */
+  override def visitConstantList(ctx: ConstantListContext): Seq[String] = withOrigin(ctx) {
+    ctx.constant.asScala.map(visitStringConstant)
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
index 7a6343748b..03079c6890 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala
@@ -18,14 +18,13 @@
 package org.apache.spark.sql.execution.command
 
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending}
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.execution.SparkQl
+import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.execution.datasources.BucketSpec
 import org.apache.spark.sql.types._
 
 class DDLCommandSuite extends PlanTest {
-  private val parser = new SparkQl
+  private val parser = SparkSqlParser
 
   test("create database") {
     val sql =
-- 
cgit v1.2.3


From 34c0638ee6f05aef81d90594dd9b8e06006c2c7f Mon Sep 17 00:00:00 2001
From: Shixiong Zhu 
Date: Mon, 28 Mar 2016 16:23:29 -0700
Subject: [SPARK-14180][CORE] Fix a deadlock in CoarseGrainedExecutorBackend
 Shutdown

## What changes were proposed in this pull request?

Call `executor.stop` in a new thread to eliminate deadlock.

## How was this patch tested?

Existing unit tests
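
The diff below shows the actual change; as a minimal sketch of the underlying pattern (hypothetical `Stoppable` trait, not a Spark API), the idea is to hand the blocking stop() call to a fresh thread, so an RPC dispatcher thread never ends up waiting for its own shutdown:

    // Minimal sketch of the pattern (hypothetical names): never block a dispatcher
    // thread on a stop() call that itself waits for the dispatcher to drain.
    trait Stoppable { def stop(): Unit }

    def stopInNewThread(resource: Stoppable): Unit = {
      new Thread("async-stop-thread") {
        override def run(): Unit = resource.stop()
      }.start()
    }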

Author: Shixiong Zhu 

Closes #12012 from zsxwing/SPARK-14180.
---
 .../apache/spark/executor/CoarseGrainedExecutorBackend.scala | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 320a20033d..81e41e6fa7 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -113,9 +113,15 @@ private[spark] class CoarseGrainedExecutorBackend(
 
     case Shutdown =>
       stopping.set(true)
-      executor.stop()
-      stop()
-      rpcEnv.shutdown()
+      new Thread("CoarseGrainedExecutorBackend-stop-executor") {
+        override def run(): Unit = {
+          // executor.stop() will call `SparkEnv.stop()` which waits until RpcEnv stops totally.
+          // However, if `executor.stop()` runs in some thread of RpcEnv, RpcEnv won't be able to
+          // stop until `executor.stop()` returns, which becomes a deadlock (see SPARK-14180).
+          // Therefore, we put this line in a new thread.
+          executor.stop()
+        }
+      }.start()
   }
 
   override def onDisconnected(remoteAddress: RpcAddress): Unit = {
-- 
cgit v1.2.3


From eebc8c1c95fb7752d09a5846b7cac65f7702c8f2 Mon Sep 17 00:00:00 2001
From: Andrew Or 
Date: Mon, 28 Mar 2016 16:25:15 -0700
Subject: [SPARK-13923][SPARK-14014][SQL] Session catalog follow-ups

## What changes were proposed in this pull request?

This patch addresses the remaining comments left in #11750 and #11918 after they were merged. For a full list of changes in this patch, just trace the commits.

## How was this patch tested?

`SessionCatalogSuite` and `CatalogTestCases`
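
Two renames in this patch show up throughout the diffs below: `CatalogTable.name`/`CatalogFunction.name` become `identifier`, and the `SessionCatalog.createTempTable` flag becomes `overrideIfExists`. A minimal sketch of the new flag's meaning, reusing the constructs visible in the SessionCatalogSuite hunks (`newBasicCatalog()` is the test utility used there):

    // Sketch only: after this patch the flag means "replace an existing temp table",
    // not "silently ignore the new definition".
    val catalog = new SessionCatalog(newBasicCatalog())
    val plan = Range(1, 10, 1, 10, Seq())
    catalog.createTempTable("tbl1", plan, overrideIfExists = false) // first registration
    catalog.createTempTable("tbl1", plan, overrideIfExists = true)  // replaces the previous plan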

Author: Andrew Or 

Closes #12006 from andrewor14/session-catalog-followup.
---
 .../sql/catalyst/catalog/InMemoryCatalog.scala     |  18 +-
 .../sql/catalyst/catalog/SessionCatalog.scala      |  74 ++---
 .../spark/sql/catalyst/catalog/interface.scala     |  14 +-
 .../spark/sql/catalyst/analysis/AnalysisTest.scala |   2 +-
 .../catalyst/analysis/DecimalPrecisionSuite.scala  |   2 +-
 .../sql/catalyst/catalog/CatalogTestCases.scala    |   6 +-
 .../sql/catalyst/catalog/SessionCatalogSuite.scala |  30 +--
 .../scala/org/apache/spark/sql/SQLContext.scala    |   2 +-
 .../spark/sql/execution/datasources/ddl.scala      |   4 +-
 .../org/apache/spark/sql/hive/HiveCatalog.scala    | 297 --------------------
 .../org/apache/spark/sql/hive/HiveContext.scala    |   4 +-
 .../spark/sql/hive/HiveExternalCatalog.scala       | 298 +++++++++++++++++++++
 .../spark/sql/hive/HiveMetastoreCatalog.scala      |  22 +-
 .../scala/org/apache/spark/sql/hive/HiveQl.scala   |   6 +-
 .../apache/spark/sql/hive/HiveSessionCatalog.scala |   6 +-
 .../apache/spark/sql/hive/client/HiveClient.scala  |   2 +-
 .../spark/sql/hive/client/HiveClientImpl.scala     |  11 +-
 .../sql/hive/execution/CreateTableAsSelect.scala   |   6 +-
 .../sql/hive/execution/CreateViewAsSelect.scala    |   4 +-
 .../org/apache/spark/sql/hive/test/TestHive.scala  |   4 +-
 .../apache/spark/sql/hive/HiveCatalogSuite.scala   |  49 ----
 .../spark/sql/hive/HiveExternalCatalogSuite.scala  |  49 ++++
 .../org/apache/spark/sql/hive/HiveQlSuite.scala    |  16 +-
 .../apache/spark/sql/hive/ListTablesSuite.scala    |   2 +-
 .../spark/sql/hive/MetastoreDataSourcesSuite.scala |   2 +-
 .../spark/sql/hive/client/VersionsSuite.scala      |   2 +-
 26 files changed, 469 insertions(+), 463 deletions(-)
 delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala
 create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
 delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala
 create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index e216fa5528..2bbb970ec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -155,7 +155,7 @@ class InMemoryCatalog extends ExternalCatalog {
       tableDefinition: CatalogTable,
       ignoreIfExists: Boolean): Unit = synchronized {
     requireDbExists(db)
-    val table = tableDefinition.name.table
+    val table = tableDefinition.identifier.table
     if (tableExists(db, table)) {
       if (!ignoreIfExists) {
         throw new AnalysisException(s"Table '$table' already exists in database '$db'")
@@ -182,14 +182,14 @@ class InMemoryCatalog extends ExternalCatalog {
   override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized {
     requireTableExists(db, oldName)
     val oldDesc = catalog(db).tables(oldName)
-    oldDesc.table = oldDesc.table.copy(name = TableIdentifier(newName, Some(db)))
+    oldDesc.table = oldDesc.table.copy(identifier = TableIdentifier(newName, Some(db)))
     catalog(db).tables.put(newName, oldDesc)
     catalog(db).tables.remove(oldName)
   }
 
   override def alterTable(db: String, tableDefinition: CatalogTable): Unit = synchronized {
-    requireTableExists(db, tableDefinition.name.table)
-    catalog(db).tables(tableDefinition.name.table).table = tableDefinition
+    requireTableExists(db, tableDefinition.identifier.table)
+    catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition
   }
 
   override def getTable(db: String, table: String): CatalogTable = synchronized {
@@ -296,10 +296,10 @@ class InMemoryCatalog extends ExternalCatalog {
 
   override def createFunction(db: String, func: CatalogFunction): Unit = synchronized {
     requireDbExists(db)
-    if (functionExists(db, func.name.funcName)) {
+    if (functionExists(db, func.identifier.funcName)) {
       throw new AnalysisException(s"Function '$func' already exists in '$db' database")
     } else {
-      catalog(db).functions.put(func.name.funcName, func)
+      catalog(db).functions.put(func.identifier.funcName, func)
     }
   }
 
@@ -310,14 +310,14 @@ class InMemoryCatalog extends ExternalCatalog {
 
   override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized {
     requireFunctionExists(db, oldName)
-    val newFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
+    val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db)))
     catalog(db).functions.remove(oldName)
     catalog(db).functions.put(newName, newFunc)
   }
 
   override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized {
-    requireFunctionExists(db, funcDefinition.name.funcName)
-    catalog(db).functions.put(funcDefinition.name.funcName, funcDefinition)
+    requireFunctionExists(db, funcDefinition.identifier.funcName)
+    catalog(db).functions.put(funcDefinition.identifier.funcName, funcDefinition)
   }
 
   override def getFunction(db: String, funcName: String): CatalogFunction = synchronized {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 34265faa74..a9cf80764d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.catalyst.catalog
 
-import java.util.concurrent.ConcurrentHashMap
-
-import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
@@ -31,6 +29,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
  * An internal catalog that is used by a Spark Session. This internal catalog serves as a
  * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary
  * tables and functions of the Spark Session that it belongs to.
+ *
+ * This class is not thread-safe.
  */
 class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
   import ExternalCatalog._
@@ -39,8 +39,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
     this(externalCatalog, new SimpleCatalystConf(true))
   }
 
-  protected[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan]
-  protected[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction]
+  protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan]
+  protected[this] val tempFunctions = new mutable.HashMap[String, CatalogFunction]
 
   // Note: we track current database here because certain operations do not explicitly
   // specify the database (e.g. DROP TABLE my_table). In these cases we must first
@@ -122,9 +122,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    * If no such database is specified, create it in the current database.
    */
   def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
-    val db = tableDefinition.name.database.getOrElse(currentDb)
-    val table = formatTableName(tableDefinition.name.table)
-    val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db)))
+    val db = tableDefinition.identifier.database.getOrElse(currentDb)
+    val table = formatTableName(tableDefinition.identifier.table)
+    val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
     externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
   }
 
@@ -138,9 +138,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    * this becomes a no-op.
    */
   def alterTable(tableDefinition: CatalogTable): Unit = {
-    val db = tableDefinition.name.database.getOrElse(currentDb)
-    val table = formatTableName(tableDefinition.name.table)
-    val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db)))
+    val db = tableDefinition.identifier.database.getOrElse(currentDb)
+    val table = formatTableName(tableDefinition.identifier.table)
+    val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
     externalCatalog.alterTable(db, newTableDefinition)
   }
 
@@ -164,9 +164,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
   def createTempTable(
       name: String,
       tableDefinition: LogicalPlan,
-      ignoreIfExists: Boolean): Unit = {
+      overrideIfExists: Boolean): Unit = {
     val table = formatTableName(name)
-    if (tempTables.containsKey(table) && !ignoreIfExists) {
+    if (tempTables.contains(table) && !overrideIfExists) {
       throw new AnalysisException(s"Temporary table '$name' already exists.")
     }
     tempTables.put(table, tableDefinition)
@@ -188,10 +188,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
     val db = oldName.database.getOrElse(currentDb)
     val oldTableName = formatTableName(oldName.table)
     val newTableName = formatTableName(newName.table)
-    if (oldName.database.isDefined || !tempTables.containsKey(oldTableName)) {
+    if (oldName.database.isDefined || !tempTables.contains(oldTableName)) {
       externalCatalog.renameTable(db, oldTableName, newTableName)
     } else {
-      val table = tempTables.remove(oldTableName)
+      val table = tempTables(oldTableName)
+      tempTables.remove(oldTableName)
       tempTables.put(newTableName, table)
     }
   }
@@ -206,7 +207,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
   def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
     val db = name.database.getOrElse(currentDb)
     val table = formatTableName(name.table)
-    if (name.database.isDefined || !tempTables.containsKey(table)) {
+    if (name.database.isDefined || !tempTables.contains(table)) {
       externalCatalog.dropTable(db, table, ignoreIfNotExists)
     } else {
       tempTables.remove(table)
@@ -224,11 +225,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
     val db = name.database.getOrElse(currentDb)
     val table = formatTableName(name.table)
     val relation =
-      if (name.database.isDefined || !tempTables.containsKey(table)) {
+      if (name.database.isDefined || !tempTables.contains(table)) {
         val metadata = externalCatalog.getTable(db, table)
         CatalogRelation(db, metadata, alias)
       } else {
-        tempTables.get(table)
+        tempTables(table)
       }
     val qualifiedTable = SubqueryAlias(table, relation)
     // If an alias was specified by the lookup, wrap the plan in a subquery so that
@@ -247,7 +248,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
   def tableExists(name: TableIdentifier): Boolean = {
     val db = name.database.getOrElse(currentDb)
     val table = formatTableName(name.table)
-    if (name.database.isDefined || !tempTables.containsKey(table)) {
+    if (name.database.isDefined || !tempTables.contains(table)) {
       externalCatalog.tableExists(db, table)
     } else {
       true // it's a temporary table
@@ -266,7 +267,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
     val dbTables =
       externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
     val regex = pattern.replaceAll("\\*", ".*").r
-    val _tempTables = tempTables.keys().asScala
+    val _tempTables = tempTables.keys.toSeq
       .filter { t => regex.pattern.matcher(t).matches() }
       .map { t => TableIdentifier(t) }
     dbTables ++ _tempTables
@@ -290,7 +291,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    * For testing only.
    */
   private[catalog] def getTempTable(name: String): Option[LogicalPlan] = {
-    Option(tempTables.get(name))
+    tempTables.get(name)
   }
 
   // ----------------------------------------------------------------------------
@@ -399,9 +400,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    * If no such database is specified, create it in the current database.
    */
   def createFunction(funcDefinition: CatalogFunction): Unit = {
-    val db = funcDefinition.name.database.getOrElse(currentDb)
+    val db = funcDefinition.identifier.database.getOrElse(currentDb)
     val newFuncDefinition = funcDefinition.copy(
-      name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
+      identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)))
     externalCatalog.createFunction(db, newFuncDefinition)
   }
 
@@ -424,9 +425,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    * this becomes a no-op.
    */
   def alterFunction(funcDefinition: CatalogFunction): Unit = {
-    val db = funcDefinition.name.database.getOrElse(currentDb)
+    val db = funcDefinition.identifier.database.getOrElse(currentDb)
     val newFuncDefinition = funcDefinition.copy(
-      name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
+      identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)))
     externalCatalog.alterFunction(db, newFuncDefinition)
   }
 
@@ -439,10 +440,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    * This assumes no database is specified in `funcDefinition`.
    */
   def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
-    require(funcDefinition.name.database.isEmpty,
+    require(funcDefinition.identifier.database.isEmpty,
       "attempted to create a temporary function while specifying a database")
-    val name = funcDefinition.name.funcName
-    if (tempFunctions.containsKey(name) && !ignoreIfExists) {
+    val name = funcDefinition.identifier.funcName
+    if (tempFunctions.contains(name) && !ignoreIfExists) {
       throw new AnalysisException(s"Temporary function '$name' already exists.")
     }
     tempFunctions.put(name, funcDefinition)
@@ -455,7 +456,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
   // Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate
   // dropFunction and dropTempFunction.
   def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
-    if (!tempFunctions.containsKey(name) && !ignoreIfNotExists) {
+    if (!tempFunctions.contains(name) && !ignoreIfNotExists) {
       throw new AnalysisException(
         s"Temporary function '$name' cannot be dropped because it does not exist!")
     }
@@ -476,11 +477,12 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
       throw new AnalysisException("rename does not support moving functions across databases")
     }
     val db = oldName.database.getOrElse(currentDb)
-    if (oldName.database.isDefined || !tempFunctions.containsKey(oldName.funcName)) {
+    if (oldName.database.isDefined || !tempFunctions.contains(oldName.funcName)) {
       externalCatalog.renameFunction(db, oldName.funcName, newName.funcName)
     } else {
-      val func = tempFunctions.remove(oldName.funcName)
-      val newFunc = func.copy(name = func.name.copy(funcName = newName.funcName))
+      val func = tempFunctions(oldName.funcName)
+      val newFunc = func.copy(identifier = func.identifier.copy(funcName = newName.funcName))
+      tempFunctions.remove(oldName.funcName)
       tempFunctions.put(newName.funcName, newFunc)
     }
   }
@@ -494,10 +496,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    */
   def getFunction(name: FunctionIdentifier): CatalogFunction = {
     val db = name.database.getOrElse(currentDb)
-    if (name.database.isDefined || !tempFunctions.containsKey(name.funcName)) {
+    if (name.database.isDefined || !tempFunctions.contains(name.funcName)) {
       externalCatalog.getFunction(db, name.funcName)
     } else {
-      tempFunctions.get(name.funcName)
+      tempFunctions(name.funcName)
     }
   }
 
@@ -510,7 +512,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
     val dbFunctions =
       externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
     val regex = pattern.replaceAll("\\*", ".*").r
-    val _tempFunctions = tempFunctions.keys().asScala
+    val _tempFunctions = tempFunctions.keys.toSeq
       .filter { f => regex.pattern.matcher(f).matches() }
       .map { f => FunctionIdentifier(f) }
     dbFunctions ++ _tempFunctions
@@ -520,7 +522,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    * Return a temporary function. For testing only.
    */
   private[catalog] def getTempFunction(name: String): Option[CatalogFunction] = {
-    Option(tempFunctions.get(name))
+    tempFunctions.get(name)
   }
 
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 34803133f6..8bb8e09a28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -169,10 +169,10 @@ abstract class ExternalCatalog {
 /**
  * A function defined in the catalog.
  *
- * @param name name of the function
+ * @param identifier name of the function
  * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc"
  */
-case class CatalogFunction(name: FunctionIdentifier, className: String)
+case class CatalogFunction(identifier: FunctionIdentifier, className: String)
 
 
 /**
@@ -216,7 +216,7 @@ case class CatalogTablePartition(
  * future once we have a better understanding of how we want to handle skewed columns.
  */
 case class CatalogTable(
-    name: TableIdentifier,
+    identifier: TableIdentifier,
     tableType: CatalogTableType,
     storage: CatalogStorageFormat,
     schema: Seq[CatalogColumn],
@@ -230,12 +230,12 @@ case class CatalogTable(
     viewText: Option[String] = None) {
 
   /** Return the database this table was specified to belong to, assuming it exists. */
-  def database: String = name.database.getOrElse {
-    throw new AnalysisException(s"table $name did not specify database")
+  def database: String = identifier.database.getOrElse {
+    throw new AnalysisException(s"table $identifier did not specify database")
   }
 
   /** Return the fully qualified name of this table, assuming the database was specified. */
-  def qualifiedName: String = name.unquotedString
+  def qualifiedName: String = identifier.unquotedString
 
   /** Syntactic sugar to update a field in `storage`. */
   def withNewStorage(
@@ -290,6 +290,6 @@ case class CatalogRelation(
   // TODO: implement this
   override def output: Seq[Attribute] = Seq.empty
 
-  require(metadata.name.database == Some(db),
+  require(metadata.identifier.database == Some(db),
     "provided database does not much the one specified in the table definition")
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 6fa4beed99..34cb97699d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -31,7 +31,7 @@ trait AnalysisTest extends PlanTest {
   private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
     val conf = new SimpleCatalystConf(caseSensitive)
     val catalog = new SessionCatalog(new InMemoryCatalog, conf)
-    catalog.createTempTable("TaBlE", TestRelations.testRelation, ignoreIfExists = true)
+    catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true)
     new Analyzer(catalog, EmptyFunctionRegistry, conf) {
       override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 31501864a8..6c08ccc34c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -52,7 +52,7 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
   private val b: Expression = UnresolvedAttribute("b")
 
   before {
-    catalog.createTempTable("table", relation, ignoreIfExists = true)
+    catalog.createTempTable("table", relation, overrideIfExists = true)
   }
 
   private def checkType(expression: Expression, expectedType: DataType): Unit = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
index 277c2d717e..959bd564d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala
@@ -210,7 +210,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
   }
 
   test("get table") {
-    assert(newBasicCatalog().getTable("db2", "tbl1").name.table == "tbl1")
+    assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1")
   }
 
   test("get table when database/table does not exist") {
@@ -452,7 +452,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
     assert(catalog.getFunction("db2", "func1").className == funcClass)
     catalog.renameFunction("db2", "func1", newName)
     intercept[AnalysisException] { catalog.getFunction("db2", "func1") }
-    assert(catalog.getFunction("db2", newName).name.funcName == newName)
+    assert(catalog.getFunction("db2", newName).identifier.funcName == newName)
     assert(catalog.getFunction("db2", newName).className == funcClass)
     intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") }
   }
@@ -549,7 +549,7 @@ abstract class CatalogTestUtils {
 
   def newTable(name: String, database: Option[String] = None): CatalogTable = {
     CatalogTable(
-      name = TableIdentifier(name, database),
+      identifier = TableIdentifier(name, database),
       tableType = CatalogTableType.EXTERNAL_TABLE,
       storage = storageFormat,
       schema = Seq(CatalogColumn("col1", "int"), CatalogColumn("col2", "string")),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 74e995cc5b..2948c5f8bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -197,17 +197,17 @@ class SessionCatalogSuite extends SparkFunSuite {
     val catalog = new SessionCatalog(newBasicCatalog())
     val tempTable1 = Range(1, 10, 1, 10, Seq())
     val tempTable2 = Range(1, 20, 2, 10, Seq())
-    catalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
-    catalog.createTempTable("tbl2", tempTable2, ignoreIfExists = false)
+    catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
+    catalog.createTempTable("tbl2", tempTable2, overrideIfExists = false)
     assert(catalog.getTempTable("tbl1") == Some(tempTable1))
     assert(catalog.getTempTable("tbl2") == Some(tempTable2))
     assert(catalog.getTempTable("tbl3") == None)
     // Temporary table already exists
     intercept[AnalysisException] {
-      catalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
+      catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
     }
     // Temporary table already exists but we override it
-    catalog.createTempTable("tbl1", tempTable2, ignoreIfExists = true)
+    catalog.createTempTable("tbl1", tempTable2, overrideIfExists = true)
     assert(catalog.getTempTable("tbl1") == Some(tempTable2))
   }
 
@@ -243,7 +243,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     val externalCatalog = newBasicCatalog()
     val sessionCatalog = new SessionCatalog(externalCatalog)
     val tempTable = Range(1, 10, 2, 10, Seq())
-    sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+    sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
     sessionCatalog.setCurrentDatabase("db2")
     assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
     assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
@@ -255,7 +255,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false)
     assert(externalCatalog.listTables("db2").toSet == Set("tbl2"))
     // If database is specified, temp tables are never dropped
-    sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+    sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
     sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false)
     sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false)
     assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
@@ -299,7 +299,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     val externalCatalog = newBasicCatalog()
     val sessionCatalog = new SessionCatalog(externalCatalog)
     val tempTable = Range(1, 10, 2, 10, Seq())
-    sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
+    sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
     sessionCatalog.setCurrentDatabase("db2")
     assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable))
     assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2"))
@@ -327,7 +327,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     assert(newTbl1.properties.get("toh") == Some("frem"))
     // Alter table without explicitly specifying database
     sessionCatalog.setCurrentDatabase("db2")
-    sessionCatalog.alterTable(tbl1.copy(name = TableIdentifier("tbl1")))
+    sessionCatalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1")))
     val newestTbl1 = externalCatalog.getTable("db2", "tbl1")
     assert(newestTbl1 == tbl1)
   }
@@ -368,7 +368,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     val sessionCatalog = new SessionCatalog(externalCatalog)
     val tempTable1 = Range(1, 10, 1, 10, Seq())
     val metastoreTable1 = externalCatalog.getTable("db2", "tbl1")
-    sessionCatalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false)
+    sessionCatalog.createTempTable("tbl1", tempTable1, overrideIfExists = false)
     sessionCatalog.setCurrentDatabase("db2")
     // If we explicitly specify the database, we'll look up the relation in that database
     assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2")))
@@ -406,7 +406,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1"))))
     // If database is explicitly specified, do not check temporary tables
     val tempTable = Range(1, 10, 1, 10, Seq())
-    catalog.createTempTable("tbl3", tempTable, ignoreIfExists = false)
+    catalog.createTempTable("tbl3", tempTable, overrideIfExists = false)
     assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2"))))
     // If database is not explicitly specified, check the current database
     catalog.setCurrentDatabase("db2")
@@ -418,8 +418,8 @@ class SessionCatalogSuite extends SparkFunSuite {
   test("list tables without pattern") {
     val catalog = new SessionCatalog(newBasicCatalog())
     val tempTable = Range(1, 10, 2, 10, Seq())
-    catalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
-    catalog.createTempTable("tbl4", tempTable, ignoreIfExists = false)
+    catalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
+    catalog.createTempTable("tbl4", tempTable, overrideIfExists = false)
     assert(catalog.listTables("db1").toSet ==
       Set(TableIdentifier("tbl1"), TableIdentifier("tbl4")))
     assert(catalog.listTables("db2").toSet ==
@@ -435,8 +435,8 @@ class SessionCatalogSuite extends SparkFunSuite {
   test("list tables with pattern") {
     val catalog = new SessionCatalog(newBasicCatalog())
     val tempTable = Range(1, 10, 2, 10, Seq())
-    catalog.createTempTable("tbl1", tempTable, ignoreIfExists = false)
-    catalog.createTempTable("tbl4", tempTable, ignoreIfExists = false)
+    catalog.createTempTable("tbl1", tempTable, overrideIfExists = false)
+    catalog.createTempTable("tbl4", tempTable, overrideIfExists = false)
     assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet)
     assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet)
     assert(catalog.listTables("db2", "tbl*").toSet ==
@@ -826,7 +826,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     sessionCatalog.createFunction(newFunc("func1", Some("db2")))
     sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4"))
     assert(sessionCatalog.getTempFunction("func4") ==
-      Some(tempFunc.copy(name = FunctionIdentifier("func4"))))
+      Some(tempFunc.copy(identifier = FunctionIdentifier("func4"))))
     assert(sessionCatalog.getTempFunction("func1") == None)
     assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3"))
     // Then, if no such temporary function exist, rename the function in the current database
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index e413e77bc1..c94600925f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -671,7 +671,7 @@ class SQLContext private[sql](
     sessionState.catalog.createTempTable(
       sessionState.sqlParser.parseTableIdentifier(tableName).table,
       df.logicalPlan,
-      ignoreIfExists = true)
+      overrideIfExists = true)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
index 24923bbb10..877e159fbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala
@@ -107,7 +107,7 @@ case class CreateTempTableUsing(
     sqlContext.sessionState.catalog.createTempTable(
       tableIdent.table,
       Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan,
-      ignoreIfExists = true)
+      overrideIfExists = true)
 
     Seq.empty[Row]
   }
@@ -138,7 +138,7 @@ case class CreateTempTableUsingAsSelect(
     sqlContext.sessionState.catalog.createTempTable(
       tableIdent.table,
       Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan,
-      ignoreIfExists = true)
+      overrideIfExists = true)
 
     Seq.empty[Row]
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala
deleted file mode 100644
index 0722fb02a8..0000000000
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala
+++ /dev/null
@@ -1,297 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import scala.util.control.NonFatal
-
-import org.apache.hadoop.hive.ql.metadata.HiveException
-import org.apache.thrift.TException
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.NoSuchItemException
-import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.hive.client.HiveClient
-
-
-/**
- * A persistent implementation of the system catalog using Hive.
- * All public methods must be synchronized for thread-safety.
- */
-private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog with Logging {
-  import ExternalCatalog._
-
-  // Exceptions thrown by the hive client that we would like to wrap
-  private val clientExceptions = Set(
-    classOf[HiveException].getCanonicalName,
-    classOf[TException].getCanonicalName)
-
-  /**
-   * Whether this is an exception thrown by the hive client that should be wrapped.
-   *
-   * Due to classloader isolation issues, pattern matching won't work here so we need
-   * to compare the canonical names of the exceptions, which we assume to be stable.
-   */
-  private def isClientException(e: Throwable): Boolean = {
-    var temp: Class[_] = e.getClass
-    var found = false
-    while (temp != null && !found) {
-      found = clientExceptions.contains(temp.getCanonicalName)
-      temp = temp.getSuperclass
-    }
-    found
-  }
-
-  /**
-   * Run some code involving `client` in a [[synchronized]] block and wrap certain
-   * exceptions thrown in the process in [[AnalysisException]].
-   */
-  private def withClient[T](body: => T): T = synchronized {
-    try {
-      body
-    } catch {
-      case e: NoSuchItemException =>
-        throw new AnalysisException(e.getMessage)
-      case NonFatal(e) if isClientException(e) =>
-        throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage)
-    }
-  }
-
-  private def requireDbMatches(db: String, table: CatalogTable): Unit = {
-    if (table.name.database != Some(db)) {
-      throw new AnalysisException(
-        s"Provided database $db does not much the one specified in the " +
-        s"table definition (${table.name.database.getOrElse("n/a")})")
-    }
-  }
-
-  private def requireTableExists(db: String, table: String): Unit = {
-    withClient { getTable(db, table) }
-  }
-
-  // --------------------------------------------------------------------------
-  // Databases
-  // --------------------------------------------------------------------------
-
-  override def createDatabase(
-      dbDefinition: CatalogDatabase,
-      ignoreIfExists: Boolean): Unit = withClient {
-    client.createDatabase(dbDefinition, ignoreIfExists)
-  }
-
-  override def dropDatabase(
-      db: String,
-      ignoreIfNotExists: Boolean,
-      cascade: Boolean): Unit = withClient {
-    client.dropDatabase(db, ignoreIfNotExists, cascade)
-  }
-
-  /**
-   * Alter a database whose name matches the one specified in `dbDefinition`,
-   * assuming the database exists.
-   *
-   * Note: As of now, this only supports altering database properties!
-   */
-  override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient {
-    val existingDb = getDatabase(dbDefinition.name)
-    if (existingDb.properties == dbDefinition.properties) {
-      logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " +
-        s"the provided database properties are the same as the old ones. Hive does not " +
-        s"currently support altering other database fields.")
-    }
-    client.alterDatabase(dbDefinition)
-  }
-
-  override def getDatabase(db: String): CatalogDatabase = withClient {
-    client.getDatabase(db)
-  }
-
-  override def databaseExists(db: String): Boolean = withClient {
-    client.getDatabaseOption(db).isDefined
-  }
-
-  override def listDatabases(): Seq[String] = withClient {
-    client.listDatabases("*")
-  }
-
-  override def listDatabases(pattern: String): Seq[String] = withClient {
-    client.listDatabases(pattern)
-  }
-
-  override def setCurrentDatabase(db: String): Unit = withClient {
-    client.setCurrentDatabase(db)
-  }
-
-  // --------------------------------------------------------------------------
-  // Tables
-  // --------------------------------------------------------------------------
-
-  override def createTable(
-      db: String,
-      tableDefinition: CatalogTable,
-      ignoreIfExists: Boolean): Unit = withClient {
-    requireDbExists(db)
-    requireDbMatches(db, tableDefinition)
-    client.createTable(tableDefinition, ignoreIfExists)
-  }
-
-  override def dropTable(
-      db: String,
-      table: String,
-      ignoreIfNotExists: Boolean): Unit = withClient {
-    requireDbExists(db)
-    client.dropTable(db, table, ignoreIfNotExists)
-  }
-
-  override def renameTable(db: String, oldName: String, newName: String): Unit = withClient {
-    val newTable = client.getTable(db, oldName).copy(name = TableIdentifier(newName, Some(db)))
-    client.alterTable(oldName, newTable)
-  }
-
-  /**
-   * Alter a table whose name that matches the one specified in `tableDefinition`,
-   * assuming the table exists.
-   *
-   * Note: As of now, this only supports altering table properties, serde properties,
-   * and num buckets!
-   */
-  override def alterTable(db: String, tableDefinition: CatalogTable): Unit = withClient {
-    requireDbMatches(db, tableDefinition)
-    requireTableExists(db, tableDefinition.name.table)
-    client.alterTable(tableDefinition)
-  }
-
-  override def getTable(db: String, table: String): CatalogTable = withClient {
-    client.getTable(db, table)
-  }
-
-  override def tableExists(db: String, table: String): Boolean = withClient {
-    client.getTableOption(db, table).isDefined
-  }
-
-  override def listTables(db: String): Seq[String] = withClient {
-    requireDbExists(db)
-    client.listTables(db)
-  }
-
-  override def listTables(db: String, pattern: String): Seq[String] = withClient {
-    requireDbExists(db)
-    client.listTables(db, pattern)
-  }
-
-  // --------------------------------------------------------------------------
-  // Partitions
-  // --------------------------------------------------------------------------
-
-  override def createPartitions(
-      db: String,
-      table: String,
-      parts: Seq[CatalogTablePartition],
-      ignoreIfExists: Boolean): Unit = withClient {
-    requireTableExists(db, table)
-    client.createPartitions(db, table, parts, ignoreIfExists)
-  }
-
-  override def dropPartitions(
-      db: String,
-      table: String,
-      parts: Seq[TablePartitionSpec],
-      ignoreIfNotExists: Boolean): Unit = withClient {
-    requireTableExists(db, table)
-    // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we
-    // need to implement it here ourselves. This is currently somewhat expensive because
-    // we make multiple synchronous calls to Hive for each partition we want to drop.
-    val partsToDrop =
-      if (ignoreIfNotExists) {
-        parts.filter { spec =>
-          try {
-            getPartition(db, table, spec)
-            true
-          } catch {
-            // Filter out the partitions that do not actually exist
-            case _: AnalysisException => false
-          }
-        }
-      } else {
-        parts
-      }
-    if (partsToDrop.nonEmpty) {
-      client.dropPartitions(db, table, partsToDrop)
-    }
-  }
-
-  override def renamePartitions(
-      db: String,
-      table: String,
-      specs: Seq[TablePartitionSpec],
-      newSpecs: Seq[TablePartitionSpec]): Unit = withClient {
-    client.renamePartitions(db, table, specs, newSpecs)
-  }
-
-  override def alterPartitions(
-      db: String,
-      table: String,
-      newParts: Seq[CatalogTablePartition]): Unit = withClient {
-    client.alterPartitions(db, table, newParts)
-  }
-
-  override def getPartition(
-      db: String,
-      table: String,
-      spec: TablePartitionSpec): CatalogTablePartition = withClient {
-    client.getPartition(db, table, spec)
-  }
-
-  override def listPartitions(
-      db: String,
-      table: String): Seq[CatalogTablePartition] = withClient {
-    client.getAllPartitions(db, table)
-  }
-
-  // --------------------------------------------------------------------------
-  // Functions
-  // --------------------------------------------------------------------------
-
-  override def createFunction(
-      db: String,
-      funcDefinition: CatalogFunction): Unit = withClient {
-    client.createFunction(db, funcDefinition)
-  }
-
-  override def dropFunction(db: String, name: String): Unit = withClient {
-    client.dropFunction(db, name)
-  }
-
-  override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient {
-    client.renameFunction(db, oldName, newName)
-  }
-
-  override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = withClient {
-    client.alterFunction(db, funcDefinition)
-  }
-
-  override def getFunction(db: String, funcName: String): CatalogFunction = withClient {
-    client.getFunction(db, funcName)
-  }
-
-  override def listFunctions(db: String, pattern: String): Seq[String] = withClient {
-    client.listFunctions(db, pattern)
-  }
-
-}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index ca3ce43591..c0b6d16d3c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -86,7 +86,7 @@ class HiveContext private[hive](
     @transient private[hive] val executionHive: HiveClientImpl,
     @transient private[hive] val metadataHive: HiveClient,
     isRootContext: Boolean,
-    @transient private[sql] val hiveCatalog: HiveCatalog)
+    @transient private[sql] val hiveCatalog: HiveExternalCatalog)
   extends SQLContext(sc, cacheManager, listener, isRootContext, hiveCatalog) with Logging {
   self =>
 
@@ -98,7 +98,7 @@ class HiveContext private[hive](
       execHive,
       metaHive,
       true,
-      new HiveCatalog(metaHive))
+      new HiveExternalCatalog(metaHive))
   }
 
   def this(sc: SparkContext) = {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
new file mode 100644
index 0000000000..f75509fe80
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -0,0 +1,298 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.util.control.NonFatal
+
+import org.apache.hadoop.hive.ql.metadata.HiveException
+import org.apache.thrift.TException
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.NoSuchItemException
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.hive.client.HiveClient
+
+
+/**
+ * A persistent implementation of the system catalog using Hive.
+ * All public methods must be synchronized for thread-safety.
+ */
+private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCatalog with Logging {
+  import ExternalCatalog._
+
+  // Exceptions thrown by the hive client that we would like to wrap
+  private val clientExceptions = Set(
+    classOf[HiveException].getCanonicalName,
+    classOf[TException].getCanonicalName)
+
+  /**
+   * Whether this is an exception thrown by the hive client that should be wrapped.
+   *
+   * Due to classloader isolation issues, pattern matching won't work here so we need
+   * to compare the canonical names of the exceptions, which we assume to be stable.
+   */
+  private def isClientException(e: Throwable): Boolean = {
+    var temp: Class[_] = e.getClass
+    var found = false
+    while (temp != null && !found) {
+      found = clientExceptions.contains(temp.getCanonicalName)
+      temp = temp.getSuperclass
+    }
+    found
+  }
+
+  /**
+   * Run some code involving `client` in a [[synchronized]] block and wrap certain
+   * exceptions thrown in the process in [[AnalysisException]].
+   */
+  private def withClient[T](body: => T): T = synchronized {
+    try {
+      body
+    } catch {
+      case e: NoSuchItemException =>
+        throw new AnalysisException(e.getMessage)
+      case NonFatal(e) if isClientException(e) =>
+        throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage)
+    }
+  }
+
+  private def requireDbMatches(db: String, table: CatalogTable): Unit = {
+    if (table.identifier.database != Some(db)) {
+      throw new AnalysisException(
+        s"Provided database $db does not much the one specified in the " +
+        s"table definition (${table.identifier.database.getOrElse("n/a")})")
+    }
+  }
+
+  private def requireTableExists(db: String, table: String): Unit = {
+    withClient { getTable(db, table) }
+  }
+
+  // --------------------------------------------------------------------------
+  // Databases
+  // --------------------------------------------------------------------------
+
+  override def createDatabase(
+      dbDefinition: CatalogDatabase,
+      ignoreIfExists: Boolean): Unit = withClient {
+    client.createDatabase(dbDefinition, ignoreIfExists)
+  }
+
+  override def dropDatabase(
+      db: String,
+      ignoreIfNotExists: Boolean,
+      cascade: Boolean): Unit = withClient {
+    client.dropDatabase(db, ignoreIfNotExists, cascade)
+  }
+
+  /**
+   * Alter a database whose name matches the one specified in `dbDefinition`,
+   * assuming the database exists.
+   *
+   * Note: As of now, this only supports altering database properties!
+   */
+  override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient {
+    val existingDb = getDatabase(dbDefinition.name)
+    if (existingDb.properties == dbDefinition.properties) {
+      logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " +
+        s"the provided database properties are the same as the old ones. Hive does not " +
+        s"currently support altering other database fields.")
+    }
+    client.alterDatabase(dbDefinition)
+  }
+
+  override def getDatabase(db: String): CatalogDatabase = withClient {
+    client.getDatabase(db)
+  }
+
+  override def databaseExists(db: String): Boolean = withClient {
+    client.getDatabaseOption(db).isDefined
+  }
+
+  override def listDatabases(): Seq[String] = withClient {
+    client.listDatabases("*")
+  }
+
+  override def listDatabases(pattern: String): Seq[String] = withClient {
+    client.listDatabases(pattern)
+  }
+
+  override def setCurrentDatabase(db: String): Unit = withClient {
+    client.setCurrentDatabase(db)
+  }
+
+  // --------------------------------------------------------------------------
+  // Tables
+  // --------------------------------------------------------------------------
+
+  override def createTable(
+      db: String,
+      tableDefinition: CatalogTable,
+      ignoreIfExists: Boolean): Unit = withClient {
+    requireDbExists(db)
+    requireDbMatches(db, tableDefinition)
+    client.createTable(tableDefinition, ignoreIfExists)
+  }
+
+  override def dropTable(
+      db: String,
+      table: String,
+      ignoreIfNotExists: Boolean): Unit = withClient {
+    requireDbExists(db)
+    client.dropTable(db, table, ignoreIfNotExists)
+  }
+
+  override def renameTable(db: String, oldName: String, newName: String): Unit = withClient {
+    val newTable = client.getTable(db, oldName)
+      .copy(identifier = TableIdentifier(newName, Some(db)))
+    client.alterTable(oldName, newTable)
+  }
+
+  /**
+   * Alter a table whose name matches the one specified in `tableDefinition`,
+   * assuming the table exists.
+   *
+   * Note: As of now, this only supports altering table properties, serde properties,
+   * and num buckets!
+   */
+  override def alterTable(db: String, tableDefinition: CatalogTable): Unit = withClient {
+    requireDbMatches(db, tableDefinition)
+    requireTableExists(db, tableDefinition.identifier.table)
+    client.alterTable(tableDefinition)
+  }
+
+  override def getTable(db: String, table: String): CatalogTable = withClient {
+    client.getTable(db, table)
+  }
+
+  override def tableExists(db: String, table: String): Boolean = withClient {
+    client.getTableOption(db, table).isDefined
+  }
+
+  override def listTables(db: String): Seq[String] = withClient {
+    requireDbExists(db)
+    client.listTables(db)
+  }
+
+  override def listTables(db: String, pattern: String): Seq[String] = withClient {
+    requireDbExists(db)
+    client.listTables(db, pattern)
+  }
+
+  // --------------------------------------------------------------------------
+  // Partitions
+  // --------------------------------------------------------------------------
+
+  override def createPartitions(
+      db: String,
+      table: String,
+      parts: Seq[CatalogTablePartition],
+      ignoreIfExists: Boolean): Unit = withClient {
+    requireTableExists(db, table)
+    client.createPartitions(db, table, parts, ignoreIfExists)
+  }
+
+  override def dropPartitions(
+      db: String,
+      table: String,
+      parts: Seq[TablePartitionSpec],
+      ignoreIfNotExists: Boolean): Unit = withClient {
+    requireTableExists(db, table)
+    // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we
+    // need to implement it here ourselves. This is currently somewhat expensive because
+    // we make multiple synchronous calls to Hive for each partition we want to drop.
+    val partsToDrop =
+      if (ignoreIfNotExists) {
+        parts.filter { spec =>
+          try {
+            getPartition(db, table, spec)
+            true
+          } catch {
+            // Filter out the partitions that do not actually exist
+            case _: AnalysisException => false
+          }
+        }
+      } else {
+        parts
+      }
+    if (partsToDrop.nonEmpty) {
+      client.dropPartitions(db, table, partsToDrop)
+    }
+  }
+
+  override def renamePartitions(
+      db: String,
+      table: String,
+      specs: Seq[TablePartitionSpec],
+      newSpecs: Seq[TablePartitionSpec]): Unit = withClient {
+    client.renamePartitions(db, table, specs, newSpecs)
+  }
+
+  override def alterPartitions(
+      db: String,
+      table: String,
+      newParts: Seq[CatalogTablePartition]): Unit = withClient {
+    client.alterPartitions(db, table, newParts)
+  }
+
+  override def getPartition(
+      db: String,
+      table: String,
+      spec: TablePartitionSpec): CatalogTablePartition = withClient {
+    client.getPartition(db, table, spec)
+  }
+
+  override def listPartitions(
+      db: String,
+      table: String): Seq[CatalogTablePartition] = withClient {
+    client.getAllPartitions(db, table)
+  }
+
+  // --------------------------------------------------------------------------
+  // Functions
+  // --------------------------------------------------------------------------
+
+  override def createFunction(
+      db: String,
+      funcDefinition: CatalogFunction): Unit = withClient {
+    client.createFunction(db, funcDefinition)
+  }
+
+  override def dropFunction(db: String, name: String): Unit = withClient {
+    client.dropFunction(db, name)
+  }
+
+  override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient {
+    client.renameFunction(db, oldName, newName)
+  }
+
+  override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = withClient {
+    client.alterFunction(db, funcDefinition)
+  }
+
+  override def getFunction(db: String, funcName: String): CatalogFunction = withClient {
+    client.getFunction(db, funcName)
+  }
+
+  override def listFunctions(db: String, pattern: String): Seq[String] = withClient {
+    client.listFunctions(db, pattern)
+  }
+
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index c7066d7363..eedd12d76a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -102,7 +102,7 @@ private[hive] object HiveSerDe {
  * Legacy catalog for interacting with the Hive metastore.
  *
  * This is still used for things like creating data source tables, but in the future will be
- * cleaned up to integrate more nicely with [[HiveCatalog]].
+ * cleaned up to integrate more nicely with [[HiveExternalCatalog]].
  */
 private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext)
   extends Logging {
@@ -124,8 +124,8 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
 
   private def getQualifiedTableName(t: CatalogTable): QualifiedTableName = {
     QualifiedTableName(
-      t.name.database.getOrElse(getCurrentDatabase).toLowerCase,
-      t.name.table.toLowerCase)
+      t.identifier.database.getOrElse(getCurrentDatabase).toLowerCase,
+      t.identifier.table.toLowerCase)
   }
 
   /** A cache of Spark SQL data source tables that have been accessed. */
@@ -299,7 +299,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
 
     def newSparkSQLSpecificMetastoreTable(): CatalogTable = {
       CatalogTable(
-        name = TableIdentifier(tblName, Option(dbName)),
+        identifier = TableIdentifier(tblName, Option(dbName)),
         tableType = tableType,
         schema = Nil,
         storage = CatalogStorageFormat(
@@ -319,7 +319,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
       assert(relation.partitionSchema.isEmpty)
 
       CatalogTable(
-        name = TableIdentifier(tblName, Option(dbName)),
+        identifier = TableIdentifier(tblName, Option(dbName)),
         tableType = tableType,
         storage = CatalogStorageFormat(
           locationUri = Some(relation.location.paths.map(_.toUri.toString).head),
@@ -431,7 +431,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
       alias match {
         // because hive use things like `_c0` to build the expanded text
         // currently we cannot support view from "create view v1(c1) as ..."
-        case None => SubqueryAlias(table.name.table, hive.parseSql(viewText))
+        case None => SubqueryAlias(table.identifier.table, hive.parseSql(viewText))
         case Some(aliasText) => SubqueryAlias(aliasText, hive.parseSql(viewText))
       }
     } else {
@@ -611,7 +611,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
         val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table)
 
         execution.CreateViewAsSelect(
-          table.copy(name = TableIdentifier(tblName, Some(dbName))),
+          table.copy(identifier = TableIdentifier(tblName, Some(dbName))),
           child,
           allowExisting,
           replace)
@@ -633,7 +633,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
         if (hive.convertCTAS && table.storage.serde.isEmpty) {
           // Do the conversion when spark.sql.hive.convertCTAS is true and the query
           // does not specify any storage format (file format and storage handler).
-          if (table.name.database.isDefined) {
+          if (table.identifier.database.isDefined) {
             throw new AnalysisException(
               "Cannot specify database name in a CTAS statement " +
                 "when spark.sql.hive.convertCTAS is set to true.")
@@ -641,7 +641,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
 
           val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists
           CreateTableUsingAsSelect(
-            TableIdentifier(desc.name.table),
+            TableIdentifier(desc.identifier.table),
             conf.defaultDataSourceName,
             temporary = false,
             Array.empty[String],
@@ -662,7 +662,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte
           val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table)
 
           execution.CreateTableAsSelect(
-            desc.copy(name = TableIdentifier(tblName, Some(dbName))),
+            desc.copy(identifier = TableIdentifier(tblName, Some(dbName))),
             child,
             allowExisting)
         }
@@ -792,7 +792,7 @@ private[hive] case class MetastoreRelation(
     // We start by constructing an API table as Hive performs several important transformations
     // internally when converting an API table to a QL table.
     val tTable = new org.apache.hadoop.hive.metastore.api.Table()
-    tTable.setTableName(table.name.table)
+    tTable.setTableName(table.identifier.table)
     tTable.setDbName(table.database)
 
     val tableParameters = new java.util.HashMap[String, String]()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index e5bcb9b1db..b3ec95fc73 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -60,7 +60,7 @@ private[hive] case class CreateTableAsSelect(
 
   override def output: Seq[Attribute] = Seq.empty[Attribute]
   override lazy val resolved: Boolean =
-    tableDesc.name.database.isDefined &&
+    tableDesc.identifier.database.isDefined &&
     tableDesc.schema.nonEmpty &&
     tableDesc.storage.serde.isDefined &&
     tableDesc.storage.inputFormat.isDefined &&
@@ -183,7 +183,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging
     val tableIdentifier = extractTableIdent(viewNameParts)
     val originalText = query.source
     val tableDesc = CatalogTable(
-      name = tableIdentifier,
+      identifier = tableIdentifier,
       tableType = CatalogTableType.VIRTUAL_VIEW,
       schema = schema,
       storage = CatalogStorageFormat(
@@ -352,7 +352,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging
 
         // TODO add bucket support
         var tableDesc: CatalogTable = CatalogTable(
-          name = tableIdentifier,
+          identifier = tableIdentifier,
           tableType =
             if (externalTable.isDefined) {
               CatalogTableType.EXTERNAL_TABLE
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index aa44cba4b5..ec7bf61be1 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.StructType
 
 
 class HiveSessionCatalog(
-    externalCatalog: HiveCatalog,
+    externalCatalog: HiveExternalCatalog,
     client: HiveClient,
     context: HiveContext,
     conf: SQLConf)
@@ -41,11 +41,11 @@ class HiveSessionCatalog(
 
   override def lookupRelation(name: TableIdentifier, alias: Option[String]): LogicalPlan = {
     val table = formatTableName(name.table)
-    if (name.database.isDefined || !tempTables.containsKey(table)) {
+    if (name.database.isDefined || !tempTables.contains(table)) {
       val newName = name.copy(table = table)
       metastoreCatalog.lookupRelation(newName, alias)
     } else {
-      val relation = tempTables.get(table)
+      val relation = tempTables(table)
       val tableWithQualifiers = SubqueryAlias(table, relation)
       // If an alias was specified by the lookup, wrap the plan in a subquery so that
       // attributes are properly qualified with this alias.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
index f4d30358ca..ee56f9d75d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala
@@ -88,7 +88,7 @@ private[hive] trait HiveClient {
   def dropTable(dbName: String, tableName: String, ignoreIfNotExists: Boolean): Unit
 
   /** Alter a table whose name matches the one specified in `table`, assuming it exists. */
-  final def alterTable(table: CatalogTable): Unit = alterTable(table.name.table, table)
+  final def alterTable(table: CatalogTable): Unit = alterTable(table.identifier.table, table)
 
   /** Updates the given table with new metadata, optionally renaming the table. */
   def alterTable(tableName: String, table: CatalogTable): Unit
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
index e4e15d13df..a31178e347 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala
@@ -298,7 +298,7 @@ private[hive] class HiveClientImpl(
     logDebug(s"Looking up $dbName.$tableName")
     Option(client.getTable(dbName, tableName, false)).map { h =>
       CatalogTable(
-        name = TableIdentifier(h.getTableName, Option(h.getDbName)),
+        identifier = TableIdentifier(h.getTableName, Option(h.getDbName)),
         tableType = h.getTableType match {
           case HiveTableType.EXTERNAL_TABLE => CatalogTableType.EXTERNAL_TABLE
           case HiveTableType.MANAGED_TABLE => CatalogTableType.MANAGED_TABLE
@@ -544,13 +544,14 @@ private[hive] class HiveClientImpl(
   }
 
   override def renameFunction(db: String, oldName: String, newName: String): Unit = withHiveState {
-    val catalogFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
+    val catalogFunc = getFunction(db, oldName)
+      .copy(identifier = FunctionIdentifier(newName, Some(db)))
     val hiveFunc = toHiveFunction(catalogFunc, db)
     client.alterFunction(db, oldName, hiveFunc)
   }
 
   override def alterFunction(db: String, func: CatalogFunction): Unit = withHiveState {
-    client.alterFunction(db, func.name.funcName, toHiveFunction(func, db))
+    client.alterFunction(db, func.identifier.funcName, toHiveFunction(func, db))
   }
 
   override def getFunctionOption(
@@ -611,7 +612,7 @@ private[hive] class HiveClientImpl(
 
   private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = {
     new HiveFunction(
-      f.name.funcName,
+      f.identifier.funcName,
       db,
       f.className,
       null,
@@ -639,7 +640,7 @@ private[hive] class HiveClientImpl(
   }
 
   private def toHiveTable(table: CatalogTable): HiveTable = {
-    val hiveTable = new HiveTable(table.database, table.name.table)
+    val hiveTable = new HiveTable(table.database, table.identifier.table)
     hiveTable.setTableType(table.tableType match {
       case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE
       case CatalogTableType.MANAGED_TABLE => HiveTableType.MANAGED_TABLE
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
index 5a61eef0f2..29f7dc2997 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala
@@ -38,7 +38,7 @@ case class CreateTableAsSelect(
     allowExisting: Boolean)
   extends RunnableCommand {
 
-  private val tableIdentifier = tableDesc.name
+  private val tableIdentifier = tableDesc.identifier
 
   override def children: Seq[LogicalPlan] = Seq(query)
 
@@ -93,6 +93,8 @@ case class CreateTableAsSelect(
   }
 
   override def argString: String = {
-    s"[Database:${tableDesc.database}}, TableName: ${tableDesc.name.table}, InsertIntoHiveTable]"
+    s"[Database:${tableDesc.database}}, " +
+    s"TableName: ${tableDesc.identifier.table}, " +
+    s"InsertIntoHiveTable]"
   }
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
index 9ff520da1d..33cd8b4480 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala
@@ -44,7 +44,7 @@ private[hive] case class CreateViewAsSelect(
   assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length)
   assert(tableDesc.viewText.isDefined)
 
-  private val tableIdentifier = tableDesc.name
+  private val tableIdentifier = tableDesc.identifier
 
   override def run(sqlContext: SQLContext): Seq[Row] = {
     val hiveContext = sqlContext.asInstanceOf[HiveContext]
@@ -116,7 +116,7 @@ private[hive] case class CreateViewAsSelect(
     }
 
     val viewText = tableDesc.viewText.get
-    val viewName = quote(tableDesc.name.table)
+    val viewName = quote(tableDesc.identifier.table)
     s"SELECT $viewOutput FROM ($viewText) $viewName"
   }
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index a1785ca038..4afc8d18a6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -78,7 +78,7 @@ class TestHiveContext private[hive](
     executionHive: HiveClientImpl,
     metadataHive: HiveClient,
     isRootContext: Boolean,
-    hiveCatalog: HiveCatalog,
+    hiveCatalog: HiveExternalCatalog,
     val warehousePath: File,
     val scratchDirPath: File,
     metastoreTemporaryConf: Map[String, String])
@@ -110,7 +110,7 @@ class TestHiveContext private[hive](
       executionHive,
       metadataHive,
       true,
-      new HiveCatalog(metadataHive),
+      new HiveExternalCatalog(metadataHive),
       warehousePath,
       scratchDirPath,
       metastoreTemporaryConf)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala
deleted file mode 100644
index 427f5747a0..0000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.util.VersionInfo
-
-import org.apache.spark.SparkConf
-import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader}
-import org.apache.spark.util.Utils
-
-/**
- * Test suite for the [[HiveCatalog]].
- */
-class HiveCatalogSuite extends CatalogTestCases {
-
-  private val client: HiveClient = {
-    IsolatedClientLoader.forVersion(
-      hiveMetastoreVersion = HiveContext.hiveExecutionVersion,
-      hadoopVersion = VersionInfo.getVersion,
-      sparkConf = new SparkConf(),
-      hadoopConf = new Configuration()).createClient()
-  }
-
-  protected override val utils: CatalogTestUtils = new CatalogTestUtils {
-    override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat"
-    override val tableOutputFormat: String = "org.apache.hadoop.mapred.SequenceFileOutputFormat"
-    override def newEmptyCatalog(): ExternalCatalog = new HiveCatalog(client)
-  }
-
-  protected override def resetState(): Unit = client.reset()
-
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
new file mode 100644
index 0000000000..3334c16f0b
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.util.VersionInfo
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader}
+import org.apache.spark.util.Utils
+
+/**
+ * Test suite for the [[HiveExternalCatalog]].
+ */
+class HiveExternalCatalogSuite extends CatalogTestCases {
+
+  private val client: HiveClient = {
+    IsolatedClientLoader.forVersion(
+      hiveMetastoreVersion = HiveContext.hiveExecutionVersion,
+      hadoopVersion = VersionInfo.getVersion,
+      sparkConf = new SparkConf(),
+      hadoopConf = new Configuration()).createClient()
+  }
+
+  protected override val utils: CatalogTestUtils = new CatalogTestUtils {
+    override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat"
+    override val tableOutputFormat: String = "org.apache.hadoop.mapred.SequenceFileOutputFormat"
+    override def newEmptyCatalog(): ExternalCatalog = new HiveExternalCatalog(client)
+  }
+
+  protected override def resetState(): Unit = client.reset()
+
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
index 1c775db9b6..0aaf57649c 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala
@@ -54,8 +54,8 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
 
     val (desc, exists) = extractTableDesc(s1)
     assert(exists)
-    assert(desc.name.database == Some("mydb"))
-    assert(desc.name.table == "page_view")
+    assert(desc.identifier.database == Some("mydb"))
+    assert(desc.identifier.table == "page_view")
     assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE)
     assert(desc.storage.locationUri == Some("/user/external/page_view"))
     assert(desc.schema ==
@@ -100,8 +100,8 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
 
     val (desc, exists) = extractTableDesc(s2)
     assert(exists)
-    assert(desc.name.database == Some("mydb"))
-    assert(desc.name.table == "page_view")
+    assert(desc.identifier.database == Some("mydb"))
+    assert(desc.identifier.table == "page_view")
     assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE)
     assert(desc.storage.locationUri == Some("/user/external/page_view"))
     assert(desc.schema ==
@@ -127,8 +127,8 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
     val s3 = """CREATE TABLE page_view AS SELECT * FROM src"""
     val (desc, exists) = extractTableDesc(s3)
     assert(exists == false)
-    assert(desc.name.database == None)
-    assert(desc.name.table == "page_view")
+    assert(desc.identifier.database == None)
+    assert(desc.identifier.table == "page_view")
     assert(desc.tableType == CatalogTableType.MANAGED_TABLE)
     assert(desc.storage.locationUri == None)
     assert(desc.schema == Seq.empty[CatalogColumn])
@@ -162,8 +162,8 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
                |   ORDER BY key, value""".stripMargin
     val (desc, exists) = extractTableDesc(s5)
     assert(exists == false)
-    assert(desc.name.database == None)
-    assert(desc.name.table == "ctas2")
+    assert(desc.identifier.database == None)
+    assert(desc.identifier.table == "ctas2")
     assert(desc.tableType == CatalogTableType.MANAGED_TABLE)
     assert(desc.storage.locationUri == None)
     assert(desc.schema == Seq.empty[CatalogColumn])
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
index 5272f4192e..e8188e5f02 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -34,7 +34,7 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft
     super.beforeAll()
     // The catalog in HiveContext is a case insensitive one.
     sessionState.catalog.createTempTable(
-      "ListTablesSuiteTable", df.logicalPlan, ignoreIfExists = true)
+      "ListTablesSuiteTable", df.logicalPlan, overrideIfExists = true)
     sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
     sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
     sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 71652897e6..3c299daa77 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -722,7 +722,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
     withTable(tableName) {
       val schema = StructType(StructField("int", IntegerType, true) :: Nil)
       val hiveTable = CatalogTable(
-        name = TableIdentifier(tableName, Some("default")),
+        identifier = TableIdentifier(tableName, Some("default")),
         tableType = CatalogTableType.MANAGED_TABLE,
         schema = Seq.empty,
         storage = CatalogStorageFormat(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
index d59bca4c7e..8b0719209d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala
@@ -148,7 +148,7 @@ class VersionsSuite extends SparkFunSuite with Logging {
     test(s"$version: createTable") {
       val table =
         CatalogTable(
-          name = TableIdentifier("src", Some("default")),
+          identifier = TableIdentifier("src", Some("default")),
           tableType = CatalogTableType.MANAGED_TABLE,
           schema = Seq(CatalogColumn("key", "int")),
           storage = CatalogStorageFormat(
-- 
cgit v1.2.3


From b7836492bb0b5b430539d2bfa20bcc32e3fe3504 Mon Sep 17 00:00:00 2001
From: Reynold Xin 
Date: Mon, 28 Mar 2016 16:26:32 -0700
Subject: [SPARK-14155][SQL] Hide UserDefinedType interface in Spark 2.0

## What changes were proposed in this pull request?
UserDefinedType is a developer API in Spark 1.x. We will very likely create a new user-defined type API that works well with both column batches and encoders (Datasets). In Spark 2.0, let's make `UserDefinedType` `private[spark]` first.
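
For readers less familiar with Scala's package-qualified access modifiers, here is a minimal, Spark-independent sketch of what `private[spark]` does (the package and class names below are hypothetical, standing in for `org.apache.spark` and third-party code):

```scala
// Hypothetical packages illustrating the visibility change; the real patch
// applies this to org.apache.spark.sql.types.UserDefinedType.
package org.example.spark {
  // Visible only to code under the org.example.spark package, mirroring
  // what `private[spark]` does to UserDefinedType.
  private[spark] abstract class RestrictedType

  package internal {
    class StillCompiles extends RestrictedType // same enclosing package: allowed
  }
}

package org.example.thirdparty {
  // class NoLongerCompiles extends org.example.spark.RestrictedType
  // ^ rejected by the compiler: RestrictedType is private to org.example.spark.
}
```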

## How was this patch tested?
Existing unit tests.

Author: Reynold Xin 

Closes #11955 from rxin/SPARK-14155.
---
 .../src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
index dabf9a2fc0..fb7251d71b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala
@@ -23,7 +23,6 @@ import org.json4s.JsonDSL._
 import org.apache.spark.annotation.DeveloperApi
 
 /**
- * ::DeveloperApi::
  * The data type for User Defined Types (UDTs).
  *
  * This interface allows a user to make their own classes more interoperable with SparkSQL;
@@ -35,8 +34,11 @@ import org.apache.spark.annotation.DeveloperApi
  *
  * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD.
  * The conversion via `deserialize` occurs when reading from a `DataFrame`.
+ *
+ * Note: This was previously a developer API in Spark 1.x. We are making this private in Spark 2.0
+ * because we will very likely create a new version of this that works better with Datasets.
  */
-@DeveloperApi
+private[spark]
 abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {
 
   /** Underlying storage type for this UDT */
-- 
cgit v1.2.3


From 2f98ee67dff0be38a4c92d7d29c8cc8ea8b6576e Mon Sep 17 00:00:00 2001
From: Shixiong Zhu 
Date: Mon, 28 Mar 2016 16:29:11 -0700
Subject: [SPARK-14169][CORE] Add UninterruptibleThread

## What changes were proposed in this pull request?

Extract the workaround for HADOOP-10622 introduced by #11940 into UninterruptibleThread so that we can test and reuse it.
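
As a rough usage sketch (the thread name and the body of the critical section are made up; the `runUninterruptibly`, `interrupt`, and `join` behaviour follows the implementation added below, and the sketch lives under the Spark namespace because the class is `private[spark]`):

```scala
package org.apache.spark.util

import java.util.concurrent.CountDownLatch

object UninterruptibleThreadExample {
  def main(args: Array[String]): Unit = {
    val entered = new CountDownLatch(1)
    val worker = new UninterruptibleThread("commit-thread") {
      override def run(): Unit = {
        runUninterruptibly {
          entered.countDown()
          // Critical section (e.g. an HDFS write guarded against HADOOP-10622):
          // interrupt() calls arriving here are recorded but not delivered.
          Thread.sleep(1000)
        }
        // Any interrupt received above is re-asserted only after the block returns.
      }
    }
    worker.start()
    entered.await()    // make sure the interrupt lands inside runUninterruptibly
    worker.interrupt() // deferred until the uninterruptible block finishes
    worker.join()
  }
}
```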

## How was this patch tested?

Unit tests

Author: Shixiong Zhu 

Closes #11971 from zsxwing/uninterrupt.
---
 .../apache/spark/util/UninterruptibleThread.scala  | 112 +++++++++++++++
 .../spark/util/UninterruptibleThreadSuite.scala    | 159 +++++++++++++++++++++
 .../sql/execution/streaming/StreamExecution.scala  |  74 ++--------
 3 files changed, 279 insertions(+), 66 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
 create mode 100644 core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala

diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
new file mode 100644
index 0000000000..4dcf95177a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import javax.annotation.concurrent.GuardedBy
+
+/**
+ * A special Thread that provides "runUninterruptibly" to allow running code without being
+ * interrupted by `Thread.interrupt()`. If `Thread.interrupt()` is called while
+ * "runUninterruptibly" is running, it won't set the interrupted status. Instead, setting the
+ * interrupted status is deferred until the thread returns from "runUninterruptibly".
+ *
+ * Note: "runUninterruptibly" should be called only in `this` thread.
+ */
+private[spark] class UninterruptibleThread(name: String) extends Thread(name) {
+
+  /** A monitor to protect "uninterruptible" and "interrupted" */
+  private val uninterruptibleLock = new Object
+
+  /**
+   * Indicates if `this` thread is in the uninterruptible status. If so, interrupting
+   * "this" will be deferred until `this` enters the interruptible status.
+   */
+  @GuardedBy("uninterruptibleLock")
+  private var uninterruptible = false
+
+  /**
+   * Indicates if we should interrupt `this` when we are leaving the uninterruptible zone.
+   */
+  @GuardedBy("uninterruptibleLock")
+  private var shouldInterruptThread = false
+
+  /**
+   * Run `f` uninterruptibly in `this` thread. The thread won't be interrupted before returning
+   * from `f`.
+   *
+   * If this method finds that `interrupt` is called before calling `f` and it's not inside another
+   * `runUninterruptibly`, it will throw `InterruptedException`.
+   *
+   * Note: this method should be called only in `this` thread.
+   */
+  def runUninterruptibly[T](f: => T): T = {
+    if (Thread.currentThread() != this) {
+      throw new IllegalStateException(s"Call runUninterruptibly in a wrong thread. " +
+        s"Expected: $this but was ${Thread.currentThread()}")
+    }
+
+    if (uninterruptibleLock.synchronized { uninterruptible }) {
+      // We are already in the uninterruptible status. So just run "f" and return
+      return f
+    }
+
+    uninterruptibleLock.synchronized {
+      // Clear the interrupted status if it's set.
+      if (Thread.interrupted() || shouldInterruptThread) {
+        shouldInterruptThread = false
+        // Since it's interrupted, we don't need to run `f` which may be a long computation.
+        // Throw InterruptedException as we don't have a T to return.
+        throw new InterruptedException()
+      }
+      uninterruptible = true
+    }
+    try {
+      f
+    } finally {
+      uninterruptibleLock.synchronized {
+        uninterruptible = false
+        if (shouldInterruptThread) {
+          // Recover the interrupted status
+          super.interrupt()
+          shouldInterruptThread = false
+        }
+      }
+    }
+  }
+
+  /**
+   * Tests whether `interrupt()` has been called.
+   */
+  override def isInterrupted: Boolean = {
+    super.isInterrupted || uninterruptibleLock.synchronized { shouldInterruptThread }
+  }
+
+  /**
+   * Interrupt `this` thread if possible. If `this` is in the uninterruptible status, it won't be
+   * interrupted until it enters the interruptible status.
+   */
+  override def interrupt(): Unit = {
+    uninterruptibleLock.synchronized {
+      if (uninterruptible) {
+        shouldInterruptThread = true
+      } else {
+        super.interrupt()
+      }
+    }
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala
new file mode 100644
index 0000000000..39b31f8dde
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/UninterruptibleThreadSuite.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.concurrent.{CountDownLatch, TimeUnit}
+
+import scala.util.Random
+
+import com.google.common.util.concurrent.Uninterruptibles
+
+import org.apache.spark.SparkFunSuite
+
+class UninterruptibleThreadSuite extends SparkFunSuite {
+
+  /** Sleep for `millis` milliseconds and return true if the sleep is interrupted */
+  private def sleep(millis: Long): Boolean = {
+    try {
+      Thread.sleep(millis)
+      false
+    } catch {
+      case _: InterruptedException =>
+        true
+    }
+  }
+
+  test("interrupt when runUninterruptibly is running") {
+    val enterRunUninterruptibly = new CountDownLatch(1)
+    @volatile var hasInterruptedException = false
+    @volatile var interruptStatusBeforeExit = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        runUninterruptibly {
+          enterRunUninterruptibly.countDown()
+          hasInterruptedException = sleep(1000)
+        }
+        interruptStatusBeforeExit = Thread.interrupted()
+      }
+    }
+    t.start()
+    assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
+    t.interrupt()
+    t.join()
+    assert(hasInterruptedException === false)
+    assert(interruptStatusBeforeExit === true)
+  }
+
+  test("interrupt before runUninterruptibly runs") {
+    val interruptLatch = new CountDownLatch(1)
+    @volatile var hasInterruptedException = false
+    @volatile var interruptStatusBeforeExit = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
+        try {
+          runUninterruptibly {
+            assert(false, "Should not reach here")
+          }
+        } catch {
+          case _: InterruptedException => hasInterruptedException = true
+        }
+        interruptStatusBeforeExit = Thread.interrupted()
+      }
+    }
+    t.start()
+    t.interrupt()
+    interruptLatch.countDown()
+    t.join()
+    assert(hasInterruptedException === true)
+    assert(interruptStatusBeforeExit === false)
+  }
+
+  test("nested runUninterruptibly") {
+    val enterRunUninterruptibly = new CountDownLatch(1)
+    val interruptLatch = new CountDownLatch(1)
+    @volatile var hasInterruptedException = false
+    @volatile var interruptStatusBeforeExit = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        runUninterruptibly {
+          enterRunUninterruptibly.countDown()
+          Uninterruptibles.awaitUninterruptibly(interruptLatch, 10, TimeUnit.SECONDS)
+          hasInterruptedException = sleep(1)
+          runUninterruptibly {
+            if (sleep(1)) {
+              hasInterruptedException = true
+            }
+          }
+          if (sleep(1)) {
+            hasInterruptedException = true
+          }
+        }
+        interruptStatusBeforeExit = Thread.interrupted()
+      }
+    }
+    t.start()
+    assert(enterRunUninterruptibly.await(10, TimeUnit.SECONDS), "await timeout")
+    t.interrupt()
+    interruptLatch.countDown()
+    t.join()
+    assert(hasInterruptedException === false)
+    assert(interruptStatusBeforeExit === true)
+  }
+
+  test("stress test") {
+    @volatile var hasInterruptedException = false
+    val t = new UninterruptibleThread("test") {
+      override def run(): Unit = {
+        for (i <- 0 until 100) {
+          try {
+            runUninterruptibly {
+              if (sleep(Random.nextInt(10))) {
+                hasInterruptedException = true
+              }
+              runUninterruptibly {
+                if (sleep(Random.nextInt(10))) {
+                  hasInterruptedException = true
+                }
+              }
+              if (sleep(Random.nextInt(10))) {
+                hasInterruptedException = true
+              }
+            }
+            Uninterruptibles.sleepUninterruptibly(Random.nextInt(10), TimeUnit.MILLISECONDS)
+            // 50% chance to clear the interrupted status
+            if (Random.nextBoolean()) {
+              Thread.interrupted()
+            }
+          } catch {
+            case _: InterruptedException =>
+              // The first runUninterruptibly may throw InterruptedException if the interrupt status
+              // is set before running `f`.
+          }
+        }
+      }
+    }
+    t.start()
+    for (i <- 0 until 400) {
+      Thread.sleep(Random.nextInt(10))
+      t.interrupt()
+    }
+    t.join()
+    assert(hasInterruptedException === false)
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 60e00d203c..c4e410d92c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.streaming
 
 import java.util.concurrent.{CountDownLatch, TimeUnit}
 import java.util.concurrent.atomic.AtomicInteger
-import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
@@ -34,6 +33,7 @@ import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.util.ContinuousQueryListener
 import org.apache.spark.sql.util.ContinuousQueryListener._
+import org.apache.spark.util.UninterruptibleThread
 
 /**
  * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread.
@@ -89,9 +89,10 @@ class StreamExecution(
   private[sql] var streamDeathCause: ContinuousQueryException = null
 
   /** The thread that runs the micro-batches of this stream. */
-  private[sql] val microBatchThread = new Thread(s"stream execution thread for $name") {
-    override def run(): Unit = { runBatches() }
-  }
+  private[sql] val microBatchThread =
+    new UninterruptibleThread(s"stream execution thread for $name") {
+      override def run(): Unit = { runBatches() }
+    }
 
   /**
    * A write-ahead-log that records the offsets that are present in each batch. In order to ensure
@@ -102,65 +103,6 @@ class StreamExecution(
   private val offsetLog =
     new HDFSMetadataLog[CompositeOffset](sqlContext, checkpointFile("offsets"))
 
-  /** A monitor to protect "uninterruptible" and "interrupted" */
-  private val uninterruptibleLock = new Object
-
-  /**
-   * Indicates if "microBatchThread" are in the uninterruptible status. If so, interrupting
-   * "microBatchThread" will be deferred until "microBatchThread" enters into the interruptible
-   * status.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var uninterruptible = false
-
-  /**
-   * Indicates if we should interrupt "microBatchThread" when we are leaving the uninterruptible
-   * zone.
-   */
-  @GuardedBy("uninterruptibleLock")
-  private var shouldInterruptThread = false
-
-  /**
-   * Interrupt "microBatchThread" if possible. If "microBatchThread" is in the uninterruptible
-   * status, "microBatchThread" won't be interrupted until it enters into the interruptible status.
-   */
-  private def interruptMicroBatchThreadSafely(): Unit = {
-    uninterruptibleLock.synchronized {
-      if (uninterruptible) {
-        shouldInterruptThread = true
-      } else {
-        microBatchThread.interrupt()
-      }
-    }
-  }
-
-  /**
-   * Run `f` uninterruptibly in "microBatchThread". "microBatchThread" won't be interrupted before
-   * returning from `f`.
-   */
-  private def runUninterruptiblyInMicroBatchThread[T](f: => T): T = {
-    assert(Thread.currentThread() == microBatchThread)
-    uninterruptibleLock.synchronized {
-      uninterruptible = true
-      // Clear the interrupted status if it's set.
-      if (Thread.interrupted()) {
-        shouldInterruptThread = true
-      }
-    }
-    try {
-      f
-    } finally {
-      uninterruptibleLock.synchronized {
-        uninterruptible = false
-        if (shouldInterruptThread) {
-          // Recover the interrupted status
-          microBatchThread.interrupt()
-          shouldInterruptThread = false
-        }
-      }
-    }
-  }
-
   /** Whether the query is currently active or not */
   override def isActive: Boolean = state == ACTIVE
 
@@ -294,7 +236,7 @@ class StreamExecution(
     // method. See SPARK-14131.
     //
     // Check to see what new data is available.
-    val newData = runUninterruptiblyInMicroBatchThread {
+    val newData = microBatchThread.runUninterruptibly {
       uniqueSources.flatMap(s => s.getOffset.map(o => s -> o))
     }
     availableOffsets ++= newData
@@ -305,7 +247,7 @@ class StreamExecution(
       // As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set
       // the file permission, we should not interrupt "microBatchThread" when running this method.
       // See SPARK-14131.
-      runUninterruptiblyInMicroBatchThread {
+      microBatchThread.runUninterruptibly {
         assert(
           offsetLog.add(currentBatchId, availableOffsets.toCompositeOffset(sources)),
           s"Concurrent update to the log.  Multiple streaming jobs detected for $currentBatchId")
@@ -395,7 +337,7 @@ class StreamExecution(
     // intentionally
     state = TERMINATED
     if (microBatchThread.isAlive) {
-      interruptMicroBatchThreadSafely()
+      microBatchThread.interrupt()
       microBatchThread.join()
     }
     logInfo(s"Query $name was stopped")
-- 
cgit v1.2.3


From 27aab80695cfcf0c0ecf1e98a5a862a8123213a1 Mon Sep 17 00:00:00 2001
From: Andrew Or 
Date: Mon, 28 Mar 2016 16:45:02 -0700
Subject: [SPARK-14013][SQL] Proper temp function support in catalog

## What changes were proposed in this pull request?

Session catalog was added in #11750. However, it does not support temporary functions properly: right now we only store their metadata as `CatalogFunction`, which does not make sense for temporary functions because they have no class name.

This patch moves the `FunctionRegistry` into the `SessionCatalog`. With this, the user can call `catalog.createTempFunction` and `catalog.lookupFunction` to use the function they registered previously. This is currently still dead code, however.
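
As a rough illustration of the new flow (a minimal sketch condensed from the `SessionCatalogSuite` changes below, not code taken verbatim from this patch):

```scala
import org.apache.spark.sql.catalyst.analysis.SimpleFunctionRegistry
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}

// A session catalog backed by an in-memory external catalog and a fresh function registry.
val catalog = new SessionCatalog(new InMemoryCatalog, new SimpleFunctionRegistry)

// A FunctionBuilder is just Seq[Expression] => Expression; this one returns its first argument.
val firstArg = (args: Seq[Expression]) => args.head

// Register the builder under a temporary name, then resolve a call to it.
catalog.createTempFunction("first", firstArg, ignoreIfExists = false)
catalog.lookupFunction("first", Seq(Literal(1), Literal(2)))   // returns Literal(1)

// Dropping the temporary function makes later lookups fail with AnalysisException.
catalog.dropTempFunction("first", ignoreIfNotExists = false)
```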

## How was this patch tested?

`SessionCatalogSuite`.

Author: Andrew Or 

Closes #11972 from andrewor14/temp-functions.
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 12 +++-
 .../sql/catalyst/analysis/FunctionRegistry.scala   | 24 +++++++
 .../sql/catalyst/catalog/SessionCatalog.scala      | 82 +++++++++++++---------
 .../spark/sql/catalyst/analysis/AnalysisTest.scala |  2 +-
 .../catalyst/analysis/DecimalPrecisionSuite.scala  |  2 +-
 .../sql/catalyst/catalog/SessionCatalogSuite.scala | 57 +++++++--------
 .../optimizer/BooleanSimplificationSuite.scala     |  2 +-
 .../catalyst/optimizer/EliminateSortsSuite.scala   |  2 +-
 .../apache/spark/sql/internal/SessionState.scala   |  8 +--
 .../apache/spark/sql/hive/HiveSessionCatalog.scala |  4 +-
 .../apache/spark/sql/hive/HiveSessionState.scala   | 14 ++--
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 10 +++
 12 files changed, 136 insertions(+), 83 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 3b83e68018..8dc0532b3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -42,9 +42,15 @@ import org.apache.spark.sql.types._
  * to resolve attribute references.
  */
 object SimpleAnalyzer
-  extends SimpleAnalyzer(new SimpleCatalystConf(caseSensitiveAnalysis = true))
-class SimpleAnalyzer(conf: CatalystConf)
-  extends Analyzer(new SessionCatalog(new InMemoryCatalog, conf), EmptyFunctionRegistry, conf)
+  extends SimpleAnalyzer(
+    EmptyFunctionRegistry,
+    new SimpleCatalystConf(caseSensitiveAnalysis = true))
+
+class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf)
+  extends Analyzer(
+    new SessionCatalog(new InMemoryCatalog, functionRegistry, conf),
+    functionRegistry,
+    conf)
 
 /**
  * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f584a4b73a..e9788b7e4d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -45,6 +45,13 @@ trait FunctionRegistry {
 
   /* Get the class of the registered function by specified name. */
   def lookupFunction(name: String): Option[ExpressionInfo]
+
+  /* Get the builder of the registered function by specified name. */
+  def lookupFunctionBuilder(name: String): Option[FunctionBuilder]
+
+  /** Drop a function and return whether the function existed. */
+  def dropFunction(name: String): Boolean
+
 }
 
 class SimpleFunctionRegistry extends FunctionRegistry {
@@ -76,6 +83,14 @@ class SimpleFunctionRegistry extends FunctionRegistry {
     functionBuilders.get(name).map(_._1)
   }
 
+  override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized {
+    functionBuilders.get(name).map(_._2)
+  }
+
+  override def dropFunction(name: String): Boolean = synchronized {
+    functionBuilders.remove(name).isDefined
+  }
+
   def copy(): SimpleFunctionRegistry = synchronized {
     val registry = new SimpleFunctionRegistry
     functionBuilders.iterator.foreach { case (name, (info, builder)) =>
@@ -106,6 +121,15 @@ object EmptyFunctionRegistry extends FunctionRegistry {
   override def lookupFunction(name: String): Option[ExpressionInfo] = {
     throw new UnsupportedOperationException
   }
+
+  override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = {
+    throw new UnsupportedOperationException
+  }
+
+  override def dropFunction(name: String): Boolean = {
+    throw new UnsupportedOperationException
+  }
+
 }
 
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index a9cf80764d..7165db1d5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -22,6 +22,9 @@ import scala.collection.mutable
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
 
 
@@ -32,15 +35,22 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
  *
  * This class is not thread-safe.
  */
-class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
+class SessionCatalog(
+    externalCatalog: ExternalCatalog,
+    functionRegistry: FunctionRegistry,
+    conf: CatalystConf) {
   import ExternalCatalog._
 
+  def this(externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry) {
+    this(externalCatalog, functionRegistry, new SimpleCatalystConf(true))
+  }
+
+  // For testing only.
   def this(externalCatalog: ExternalCatalog) {
-    this(externalCatalog, new SimpleCatalystConf(true))
+    this(externalCatalog, new SimpleFunctionRegistry)
   }
 
   protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan]
-  protected[this] val tempFunctions = new mutable.HashMap[String, CatalogFunction]
 
   // Note: we track current database here because certain operations do not explicitly
   // specify the database (e.g. DROP TABLE my_table). In these cases we must first
@@ -431,6 +441,18 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
     externalCatalog.alterFunction(db, newFuncDefinition)
   }
 
+  /**
+   * Retrieve the metadata of a metastore function.
+   *
+   * If a database is specified in `name`, this will return the function in that database.
+   * If no database is specified, this will return the function in the current database.
+   */
+  def getFunction(name: FunctionIdentifier): CatalogFunction = {
+    val db = name.database.getOrElse(currentDb)
+    externalCatalog.getFunction(db, name.funcName)
+  }
+
+
   // ----------------------------------------------------------------
   // | Methods that interact with temporary and metastore functions |
   // ----------------------------------------------------------------
@@ -439,14 +461,14 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
    * Create a temporary function.
    * This assumes no database is specified in `funcDefinition`.
    */
-  def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
-    require(funcDefinition.identifier.database.isEmpty,
-      "attempted to create a temporary function while specifying a database")
-    val name = funcDefinition.identifier.funcName
-    if (tempFunctions.contains(name) && !ignoreIfExists) {
+  def createTempFunction(
+      name: String,
+      funcDefinition: FunctionBuilder,
+      ignoreIfExists: Boolean): Unit = {
+    if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) {
       throw new AnalysisException(s"Temporary function '$name' already exists.")
     }
-    tempFunctions.put(name, funcDefinition)
+    functionRegistry.registerFunction(name, funcDefinition)
   }
 
   /**
@@ -456,11 +478,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
   // Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate
   // dropFunction and dropTempFunction.
   def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = {
-    if (!tempFunctions.contains(name) && !ignoreIfNotExists) {
+    if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) {
       throw new AnalysisException(
         s"Temporary function '$name' cannot be dropped because it does not exist!")
     }
-    tempFunctions.remove(name)
   }
 
   /**
@@ -477,34 +498,29 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
       throw new AnalysisException("rename does not support moving functions across databases")
     }
     val db = oldName.database.getOrElse(currentDb)
-    if (oldName.database.isDefined || !tempFunctions.contains(oldName.funcName)) {
+    val oldBuilder = functionRegistry.lookupFunctionBuilder(oldName.funcName)
+    if (oldName.database.isDefined || oldBuilder.isEmpty) {
       externalCatalog.renameFunction(db, oldName.funcName, newName.funcName)
     } else {
-      val func = tempFunctions(oldName.funcName)
-      val newFunc = func.copy(identifier = func.identifier.copy(funcName = newName.funcName))
-      tempFunctions.remove(oldName.funcName)
-      tempFunctions.put(newName.funcName, newFunc)
+      val oldExpressionInfo = functionRegistry.lookupFunction(oldName.funcName).get
+      val newExpressionInfo = new ExpressionInfo(
+        oldExpressionInfo.getClassName,
+        newName.funcName,
+        oldExpressionInfo.getUsage,
+        oldExpressionInfo.getExtended)
+      functionRegistry.dropFunction(oldName.funcName)
+      functionRegistry.registerFunction(newName.funcName, newExpressionInfo, oldBuilder.get)
     }
   }
 
   /**
-   * Retrieve the metadata of an existing function.
-   *
-   * If a database is specified in `name`, this will return the function in that database.
-   * If no database is specified, this will first attempt to return a temporary function with
-   * the same name, then, if that does not exist, return the function in the current database.
+   * Return an [[Expression]] that represents the specified function, assuming it exists.
+   * Note: This is currently only used for temporary functions.
    */
-  def getFunction(name: FunctionIdentifier): CatalogFunction = {
-    val db = name.database.getOrElse(currentDb)
-    if (name.database.isDefined || !tempFunctions.contains(name.funcName)) {
-      externalCatalog.getFunction(db, name.funcName)
-    } else {
-      tempFunctions(name.funcName)
-    }
+  def lookupFunction(name: String, children: Seq[Expression]): Expression = {
+    functionRegistry.lookupFunction(name, children)
   }
 
-  // TODO: implement lookupFunction that returns something from the registry itself
-
   /**
    * List all matching functions in the specified database, including temporary functions.
    */
@@ -512,7 +528,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
     val dbFunctions =
       externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
     val regex = pattern.replaceAll("\\*", ".*").r
-    val _tempFunctions = tempFunctions.keys.toSeq
+    val _tempFunctions = functionRegistry.listFunction()
       .filter { f => regex.pattern.matcher(f).matches() }
       .map { f => FunctionIdentifier(f) }
     dbFunctions ++ _tempFunctions
@@ -521,8 +537,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) {
   /**
    * Return a temporary function. For testing only.
    */
-  private[catalog] def getTempFunction(name: String): Option[CatalogFunction] = {
-    tempFunctions.get(name)
+  private[catalog] def getTempFunction(name: String): Option[FunctionBuilder] = {
+    functionRegistry.lookupFunctionBuilder(name)
   }
 
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 34cb97699d..3ec95ef5b5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -30,7 +30,7 @@ trait AnalysisTest extends PlanTest {
 
   private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
     val conf = new SimpleCatalystConf(caseSensitive)
-    val catalog = new SessionCatalog(new InMemoryCatalog, conf)
+    val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
     catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true)
     new Analyzer(catalog, EmptyFunctionRegistry, conf) {
       override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
index 6c08ccc34c..1a350bf847 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types._
 
 class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter {
   private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true)
-  private val catalog = new SessionCatalog(new InMemoryCatalog, conf)
+  private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
   private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
 
   private val relation = LocalRelation(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 2948c5f8bd..acd97592b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias}
 
 
@@ -682,20 +683,20 @@ class SessionCatalogSuite extends SparkFunSuite {
 
   test("create temp function") {
     val catalog = new SessionCatalog(newBasicCatalog())
-    val tempFunc1 = newFunc("temp1")
-    val tempFunc2 = newFunc("temp2")
-    catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
-    catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
+    val tempFunc1 = (e: Seq[Expression]) => e.head
+    val tempFunc2 = (e: Seq[Expression]) => e.last
+    catalog.createTempFunction("temp1", tempFunc1, ignoreIfExists = false)
+    catalog.createTempFunction("temp2", tempFunc2, ignoreIfExists = false)
     assert(catalog.getTempFunction("temp1") == Some(tempFunc1))
     assert(catalog.getTempFunction("temp2") == Some(tempFunc2))
     assert(catalog.getTempFunction("temp3") == None)
+    val tempFunc3 = (e: Seq[Expression]) => Literal(e.size)
     // Temporary function already exists
     intercept[AnalysisException] {
-      catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
+      catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = false)
     }
     // Temporary function is overridden
-    val tempFunc3 = tempFunc1.copy(className = "something else")
-    catalog.createTempFunction(tempFunc3, ignoreIfExists = true)
+    catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = true)
     assert(catalog.getTempFunction("temp1") == Some(tempFunc3))
   }
 
@@ -725,8 +726,8 @@ class SessionCatalogSuite extends SparkFunSuite {
 
   test("drop temp function") {
     val catalog = new SessionCatalog(newBasicCatalog())
-    val tempFunc = newFunc("func1")
-    catalog.createTempFunction(tempFunc, ignoreIfExists = false)
+    val tempFunc = (e: Seq[Expression]) => e.head
+    catalog.createTempFunction("func1", tempFunc, ignoreIfExists = false)
     assert(catalog.getTempFunction("func1") == Some(tempFunc))
     catalog.dropTempFunction("func1", ignoreIfNotExists = false)
     assert(catalog.getTempFunction("func1") == None)
@@ -755,20 +756,15 @@ class SessionCatalogSuite extends SparkFunSuite {
     }
   }
 
-  test("get temp function") {
-    val externalCatalog = newBasicCatalog()
-    val sessionCatalog = new SessionCatalog(externalCatalog)
-    val metastoreFunc = externalCatalog.getFunction("db2", "func1")
-    val tempFunc = newFunc("func1").copy(className = "something weird")
-    sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
-    sessionCatalog.setCurrentDatabase("db2")
-    // If a database is specified, we'll always return the function in that database
-    assert(sessionCatalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == metastoreFunc)
-    // If no database is specified, we'll first return temporary functions
-    assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == tempFunc)
-    // Then, if no such temporary function exist, check the current database
-    sessionCatalog.dropTempFunction("func1", ignoreIfNotExists = false)
-    assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == metastoreFunc)
+  test("lookup temp function") {
+    val catalog = new SessionCatalog(newBasicCatalog())
+    val tempFunc1 = (e: Seq[Expression]) => e.head
+    catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false)
+    assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1))
+    catalog.dropTempFunction("func1", ignoreIfNotExists = false)
+    intercept[AnalysisException] {
+      catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3)))
+    }
   }
 
   test("rename function") {
@@ -813,8 +809,8 @@ class SessionCatalogSuite extends SparkFunSuite {
   test("rename temp function") {
     val externalCatalog = newBasicCatalog()
     val sessionCatalog = new SessionCatalog(externalCatalog)
-    val tempFunc = newFunc("func1").copy(className = "something weird")
-    sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false)
+    val tempFunc = (e: Seq[Expression]) => e.head
+    sessionCatalog.createTempFunction("func1", tempFunc, ignoreIfExists = false)
     sessionCatalog.setCurrentDatabase("db2")
     // If a database is specified, we'll always rename the function in that database
     sessionCatalog.renameFunction(
@@ -825,8 +821,7 @@ class SessionCatalogSuite extends SparkFunSuite {
     // If no database is specified, we'll first rename temporary functions
     sessionCatalog.createFunction(newFunc("func1", Some("db2")))
     sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4"))
-    assert(sessionCatalog.getTempFunction("func4") ==
-      Some(tempFunc.copy(identifier = FunctionIdentifier("func4"))))
+    assert(sessionCatalog.getTempFunction("func4") == Some(tempFunc))
     assert(sessionCatalog.getTempFunction("func1") == None)
     assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3"))
     // Then, if no such temporary function exist, rename the function in the current database
@@ -858,12 +853,12 @@ class SessionCatalogSuite extends SparkFunSuite {
 
   test("list functions") {
     val catalog = new SessionCatalog(newBasicCatalog())
-    val tempFunc1 = newFunc("func1").copy(className = "march")
-    val tempFunc2 = newFunc("yes_me").copy(className = "april")
+    val tempFunc1 = (e: Seq[Expression]) => e.head
+    val tempFunc2 = (e: Seq[Expression]) => e.last
     catalog.createFunction(newFunc("func2", Some("db2")))
     catalog.createFunction(newFunc("not_me", Some("db2")))
-    catalog.createTempFunction(tempFunc1, ignoreIfExists = false)
-    catalog.createTempFunction(tempFunc2, ignoreIfExists = false)
+    catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false)
+    catalog.createTempFunction("yes_me", tempFunc2, ignoreIfExists = false)
     assert(catalog.listFunctions("db1", "*").toSet ==
       Set(FunctionIdentifier("func1"),
         FunctionIdentifier("yes_me")))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index e2c76b700f..dd6b5cac28 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -140,7 +140,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
 
   private val caseInsensitiveConf = new SimpleCatalystConf(false)
   private val caseInsensitiveAnalyzer = new Analyzer(
-    new SessionCatalog(new InMemoryCatalog, caseInsensitiveConf),
+    new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf),
     EmptyFunctionRegistry,
     caseInsensitiveConf)
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
index 3824c67563..009889d5a1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._
 
 class EliminateSortsSuite extends PlanTest {
   val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false)
-  val catalog = new SessionCatalog(new InMemoryCatalog, conf)
+  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
   val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf)
 
   object Optimize extends RuleExecutor[LogicalPlan] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 9bc640763f..f7fdfacd31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -44,14 +44,14 @@ private[sql] class SessionState(ctx: SQLContext) {
   lazy val experimentalMethods = new ExperimentalMethods
 
   /**
-   * Internal catalog for managing table and database states.
+   * Internal catalog for managing functions registered by the user.
    */
-  lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf)
+  lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
 
   /**
-   * Internal catalog for managing functions registered by the user.
+   * Internal catalog for managing table and database states.
    */
-  lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
+  lazy val catalog = new SessionCatalog(ctx.externalCatalog, functionRegistry, conf)
 
   /**
    * Interface exposed to the user for registering user-defined functions.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index ec7bf61be1..ff12245e8d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive
 
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -31,8 +32,9 @@ class HiveSessionCatalog(
     externalCatalog: HiveExternalCatalog,
     client: HiveClient,
     context: HiveContext,
+    functionRegistry: FunctionRegistry,
     conf: SQLConf)
-  extends SessionCatalog(externalCatalog, conf) {
+  extends SessionCatalog(externalCatalog, functionRegistry, conf) {
 
   override def setCurrentDatabase(db: String): Unit = {
     super.setCurrentDatabase(db)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index caa7f296ed..c9b6b1dfb6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -34,13 +34,6 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
     override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false)
   }
 
-  /**
-   * Internal catalog for managing table and database states.
-   */
-  override lazy val catalog = {
-    new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, conf)
-  }
-
   /**
    * Internal catalog for managing functions registered by the user.
    * Note that HiveUDFs will be overridden by functions registered in this context.
@@ -49,6 +42,13 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
     new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), ctx.executionHive)
   }
 
+  /**
+   * Internal catalog for managing table and database states.
+   */
+  override lazy val catalog = {
+    new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, functionRegistry, conf)
+  }
+
   /**
    * An analyzer that uses the Hive metastore.
    */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index efaa052370..c07c428895 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -141,6 +141,16 @@ private[hive] class HiveFunctionRegistry(
       }
     }.getOrElse(None))
   }
+
+  override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = {
+    underlying.lookupFunctionBuilder(name)
+  }
+
+  // Note: This does not drop functions stored in the metastore
+  override def dropFunction(name: String): Boolean = {
+    underlying.dropFunction(name)
+  }
+
 }
 
 private[hive] case class HiveSimpleUDF(
-- 
cgit v1.2.3


From a916d2a454b63a4c234b1e0b5bf9c5b212bd37fa Mon Sep 17 00:00:00 2001
From: Andrew Or 
Date: Mon, 28 Mar 2016 16:45:31 -0700
Subject: [SPARK-14119][SPARK-14120][SPARK-14122][SQL] Throw exception on
 unsupported DDL commands

## What changes were proposed in this pull request?

Before: We just pass all role commands to Hive even though they don't work.
After: We throw an `AnalysisException` that looks like this:

```
scala> sql("CREATE ROLE x")
org.apache.spark.sql.AnalysisException: Unsupported Hive operation: CREATE ROLE;
  at org.apache.spark.sql.hive.HiveQl$$anonfun$parsePlan$1.apply(HiveQl.scala:213)
  at org.apache.spark.sql.hive.HiveQl$$anonfun$parsePlan$1.apply(HiveQl.scala:208)
  at org.apache.spark.sql.catalyst.parser.CatalystQl.safeParse(CatalystQl.scala:49)
  at org.apache.spark.sql.hive.HiveQl.parsePlan(HiveQl.scala:208)
  at org.apache.spark.sql.SQLContext.parseSql(SQLContext.scala:198)
```
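
A condensed sketch of how the check works (the object name below is illustrative; in the patch the map and the check live directly in `HiveQl.parsePlan`):

```scala
// Placed in the hive package only so this sketch can construct AnalysisException directly.
package org.apache.spark.sql.hive

import org.apache.spark.sql.AnalysisException

object UnsupportedHiveCommandCheck {
  // Parser token text -> human-readable command name (only a few entries shown here).
  private val hiveUnsupportedCommands = Map(
    "TOK_CREATEROLE" -> "CREATE ROLE",
    "TOK_GRANT" -> "GRANT",
    "TOK_SHOW_CREATETABLE" -> "SHOW CREATE TABLE")

  // Checked against the root AST token text before a logical plan is built.
  def failIfUnsupported(tokenText: String): Unit = {
    hiveUnsupportedCommands.get(tokenText).foreach { humanReadable =>
      throw new AnalysisException("Unsupported operation: " + humanReadable)
    }
  }
}
```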

## How was this patch tested?

`HiveQuerySuite`

Author: Andrew Or 

Closes #11948 from andrewor14/ddl-role-management.
---
 .../hive/execution/HiveCompatibilitySuite.scala    | 17 ++++----
 .../scala/org/apache/spark/sql/hive/HiveQl.scala   | 50 +++++++++++-----------
 .../spark/sql/hive/execution/HiveQuerySuite.scala  | 34 +++++++++++++++
 3 files changed, 69 insertions(+), 32 deletions(-)

diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 650797f768..bedbf9ae17 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -291,7 +291,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "compute_stats_empty_table",
     "compute_stats_long",
     "create_view_translate",
-    "show_create_table_serde",
     "show_tblproperties",
 
     // Odd changes to output
@@ -344,6 +343,15 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     // These tests check the VIEW table definition, but Spark handles CREATE VIEW itself and
     // generates different View Expanded Text.
     "alter_view_as_select",
+
+    // We don't support show create table commands in general
+    "show_create_table_alter",
+    "show_create_table_db_table",
+    "show_create_table_delimited",
+    "show_create_table_does_not_exist",
+    "show_create_table_index",
+    "show_create_table_partitioned",
+    "show_create_table_serde",
     "show_create_table_view"
   )
 
@@ -833,13 +841,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
     "serde_reported_schema",
     "set_variable_sub",
     "show_columns",
-    "show_create_table_alter",
-    "show_create_table_db_table",
-    "show_create_table_delimited",
-    "show_create_table_does_not_exist",
-    "show_create_table_index",
-    "show_create_table_partitioned",
-    "show_create_table_serde",
     "show_describe_func_quotes",
     "show_functions",
     "show_partitions",
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index b3ec95fc73..052c43a3ce 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -31,7 +31,6 @@ import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.parser._
@@ -83,7 +82,28 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging
   import ParseUtils._
   import ParserUtils._
 
-  protected val nativeCommands = Seq(
+  // Token text -> human readable text
+  private val hiveUnsupportedCommands = Map(
+    "TOK_CREATEROLE" -> "CREATE ROLE",
+    "TOK_DROPROLE" -> "DROP ROLE",
+    "TOK_EXPORT" -> "EXPORT TABLE",
+    "TOK_GRANT" -> "GRANT",
+    "TOK_GRANT_ROLE" -> "GRANT",
+    "TOK_IMPORT" -> "IMPORT TABLE",
+    "TOK_REVOKE" -> "REVOKE",
+    "TOK_REVOKE_ROLE" -> "REVOKE",
+    "TOK_SHOW_COMPACTIONS" -> "SHOW COMPACTIONS",
+    "TOK_SHOW_CREATETABLE" -> "SHOW CREATE TABLE",
+    "TOK_SHOW_GRANT" -> "SHOW GRANT",
+    "TOK_SHOW_ROLE_GRANT" -> "SHOW ROLE GRANT",
+    "TOK_SHOW_ROLE_PRINCIPALS" -> "SHOW PRINCIPALS",
+    "TOK_SHOW_ROLES" -> "SHOW ROLES",
+    "TOK_SHOW_SET_ROLE" -> "SHOW CURRENT ROLES / SET ROLE",
+    "TOK_SHOW_TRANSACTIONS" -> "SHOW TRANSACTIONS",
+    "TOK_SHOWINDEXES" -> "SHOW INDEXES",
+    "TOK_SHOWLOCKS" -> "SHOW LOCKS")
+
+  private val nativeCommands = Set(
     "TOK_ALTERDATABASE_OWNER",
     "TOK_ALTERINDEX_PROPERTIES",
     "TOK_ALTERINDEX_REBUILD",
@@ -97,51 +117,30 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging
 
     "TOK_CREATEINDEX",
     "TOK_CREATEMACRO",
-    "TOK_CREATEROLE",
 
     "TOK_DROPINDEX",
     "TOK_DROPMACRO",
-    "TOK_DROPROLE",
     "TOK_DROPTABLE_PROPERTIES",
     "TOK_DROPVIEW",
     "TOK_DROPVIEW_PROPERTIES",
 
-    "TOK_EXPORT",
-
-    "TOK_GRANT",
-    "TOK_GRANT_ROLE",
-
-    "TOK_IMPORT",
-
     "TOK_LOAD",
 
     "TOK_LOCKTABLE",
 
     "TOK_MSCK",
 
-    "TOK_REVOKE",
-
-    "TOK_SHOW_COMPACTIONS",
-    "TOK_SHOW_CREATETABLE",
-    "TOK_SHOW_GRANT",
-    "TOK_SHOW_ROLE_GRANT",
-    "TOK_SHOW_ROLE_PRINCIPALS",
-    "TOK_SHOW_ROLES",
-    "TOK_SHOW_SET_ROLE",
     "TOK_SHOW_TABLESTATUS",
     "TOK_SHOW_TBLPROPERTIES",
-    "TOK_SHOW_TRANSACTIONS",
     "TOK_SHOWCOLUMNS",
     "TOK_SHOWDATABASES",
-    "TOK_SHOWINDEXES",
-    "TOK_SHOWLOCKS",
     "TOK_SHOWPARTITIONS",
 
     "TOK_UNLOCKTABLE"
   )
 
   // Commands that we do not need to explain.
-  protected val noExplainCommands = Seq(
+  private val noExplainCommands = Set(
     "TOK_DESCTABLE",
     "TOK_SHOWTABLES",
     "TOK_TRUNCATETABLE", // truncate table" is a NativeCommand, does not need to explain.
@@ -209,6 +208,9 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging
     safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast =>
       if (nativeCommands.contains(ast.text)) {
         HiveNativeCommand(sql)
+      } else if (hiveUnsupportedCommands.contains(ast.text)) {
+        val humanReadableText = hiveUnsupportedCommands(ast.text)
+        throw new AnalysisException("Unsupported operation: " + humanReadableText)
       } else {
         nodeToPlan(ast) match {
           case NativePlaceholder => HiveNativeCommand(sql)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 197a123905..79774f5913 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -68,6 +68,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
     }
   }
 
+  private def assertUnsupportedFeature(body: => Unit): Unit = {
+    val e = intercept[AnalysisException] { body }
+    assert(e.getMessage.toLowerCase.contains("unsupported operation"))
+  }
+
   test("SPARK-4908: concurrent hive native commands") {
     (1 to 100).par.map { _ =>
       sql("USE default")
@@ -1246,6 +1251,35 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
 
   // Put tests that depend on specific Hive settings before these last two test,
   // since they modify /clear stuff.
+
+  test("role management commands are not supported") {
+    assertUnsupportedFeature { sql("CREATE ROLE my_role") }
+    assertUnsupportedFeature { sql("DROP ROLE my_role") }
+    assertUnsupportedFeature { sql("SHOW CURRENT ROLES") }
+    assertUnsupportedFeature { sql("SHOW ROLES") }
+    assertUnsupportedFeature { sql("SHOW GRANT") }
+    assertUnsupportedFeature { sql("SHOW ROLE GRANT USER my_principal") }
+    assertUnsupportedFeature { sql("SHOW PRINCIPALS my_role") }
+    assertUnsupportedFeature { sql("SET ROLE my_role") }
+    assertUnsupportedFeature { sql("GRANT my_role TO USER my_user") }
+    assertUnsupportedFeature { sql("GRANT ALL ON my_table TO USER my_user") }
+    assertUnsupportedFeature { sql("REVOKE my_role FROM USER my_user") }
+    assertUnsupportedFeature { sql("REVOKE ALL ON my_table FROM USER my_user") }
+  }
+
+  test("import/export commands are not supported") {
+    assertUnsupportedFeature { sql("IMPORT TABLE my_table FROM 'my_path'") }
+    assertUnsupportedFeature { sql("EXPORT TABLE my_table TO 'my_path'") }
+  }
+
+  test("some show commands are not supported") {
+    assertUnsupportedFeature { sql("SHOW CREATE TABLE my_table") }
+    assertUnsupportedFeature { sql("SHOW COMPACTIONS") }
+    assertUnsupportedFeature { sql("SHOW TRANSACTIONS") }
+    assertUnsupportedFeature { sql("SHOW INDEXES ON my_table") }
+    assertUnsupportedFeature { sql("SHOW LOCKS my_table") }
+  }
+
 }
 
 // for SPARK-2180 test
-- 
cgit v1.2.3


From ad9e3d50f71b096872806a86d89c03a208b1cf8b Mon Sep 17 00:00:00 2001
From: jeanlyn 
Date: Mon, 28 Mar 2016 16:56:25 -0700
Subject: [SPARK-13845][CORE] Using onBlockUpdated to replace onTaskEnd
 avoiding driver OOM

## What changes were proposed in this pull request?

We have a streaming job using `FlumePollInputStream` whose driver always OOMs after a few days; here is part of a driver heap dump taken before the OOM:
```
 num     #instances         #bytes  class name
----------------------------------------------
   1:      13845916      553836640  org.apache.spark.storage.BlockStatus
   2:      14020324      336487776  org.apache.spark.storage.StreamBlockId
   3:      13883881      333213144  scala.collection.mutable.DefaultEntry
   4:          8907       89043952  [Lscala.collection.mutable.HashEntry;
   5:         62360       65107352  [B
   6:        163368       24453904  [Ljava.lang.Object;
   7:        293651       20342664  [C
...
```
The number of `BlockStatus` and `StreamBlockId` instances keeps growing until the driver eventually OOMs.
After investigating, I found that `executorIdToStorageStatus` in `StorageStatusListener` never removes blocks from `StorageStatus`.
To fix the issue, this patch uses `onBlockUpdated` instead of `onTaskEnd`, so block information (adding blocks, dropping blocks from memory to disk, and deleting blocks) is kept up to date.
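
A minimal sketch of the listener side of this change (the class name is illustrative; the real changes are in `StorageStatusListener` and `StorageListener` below):

```scala
import scala.collection.mutable

import org.apache.spark.scheduler.{SparkListener, SparkListenerBlockUpdated}
import org.apache.spark.storage.{BlockId, BlockStatus}

// Keep block state current from per-block update events instead of task-end metrics.
class BlockTrackingListener extends SparkListener {
  private val blocks = mutable.HashMap[BlockId, BlockStatus]()

  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = synchronized {
    val info = blockUpdated.blockUpdatedInfo
    val status = BlockStatus(info.storageLevel, info.memSize, info.diskSize)
    if (status.storageLevel.isValid) {
      blocks(info.blockId) = status   // block added, or moved between memory and disk
    } else {
      blocks.remove(info.blockId)     // block removed: stop tracking it, freeing the entry
    }
  }
}
```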

## How was this patch tested?

Existing unit tests and manual tests

Author: jeanlyn 

Closes #11779 from jeanlyn/fix_driver_oom.
---
 .../spark/storage/StorageStatusListener.scala      | 21 ++++---
 .../org/apache/spark/ui/storage/StorageTab.scala   | 21 ++++---
 .../executor_list_json_expectation.json            |  4 +-
 .../rdd_list_storage_json_expectation.json         | 10 +---
 .../spark/deploy/history/HistoryServerSuite.scala  |  5 +-
 .../spark/storage/StorageStatusListenerSuite.scala | 67 ++++++++++++----------
 .../apache/spark/ui/storage/StorageTabSuite.scala  | 58 +++++++++----------
 7 files changed, 91 insertions(+), 95 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index f552b498a7..3008520f61 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -66,17 +66,6 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener {
     }
   }
 
-  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
-    val info = taskEnd.taskInfo
-    val metrics = taskEnd.taskMetrics
-    if (info != null && metrics != null) {
-      val updatedBlocks = metrics.updatedBlockStatuses
-      if (updatedBlocks.length > 0) {
-        updateStorageStatus(info.executorId, updatedBlocks)
-      }
-    }
-  }
-
   override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized {
     updateStorageStatus(unpersistRDD.rddId)
   }
@@ -102,4 +91,14 @@ class StorageStatusListener(conf: SparkConf) extends SparkListener {
       }
     }
   }
+
+  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+    val executorId = blockUpdated.blockUpdatedInfo.blockManagerId.executorId
+    val blockId = blockUpdated.blockUpdatedInfo.blockId
+    val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
+    val memSize = blockUpdated.blockUpdatedInfo.memSize
+    val diskSize = blockUpdated.blockUpdatedInfo.diskSize
+    val blockStatus = BlockStatus(storageLevel, memSize, diskSize)
+    updateStorageStatus(executorId, Seq((blockId, blockStatus)))
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 8f75b586e1..50095831b4 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -57,17 +57,6 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc
     StorageUtils.updateRddInfo(rddInfosToUpdate, activeStorageStatusList)
   }
 
-  /**
-   * Assumes the storage status list is fully up-to-date. This implies the corresponding
-   * StorageStatusSparkListener must process the SparkListenerTaskEnd event before this listener.
-   */
-  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
-    val metrics = taskEnd.taskMetrics
-    if (metrics != null && metrics.updatedBlockStatuses.nonEmpty) {
-      updateRDDInfo(metrics.updatedBlockStatuses)
-    }
-  }
-
   override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized {
     val rddInfos = stageSubmitted.stageInfo.rddInfos
     rddInfos.foreach { info => _rddInfoMap.getOrElseUpdate(info.id, info) }
@@ -84,4 +73,14 @@ class StorageListener(storageStatusListener: StorageStatusListener) extends Bloc
   override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = synchronized {
     _rddInfoMap.remove(unpersistRDD.rddId)
   }
+
+  override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = {
+    super.onBlockUpdated(blockUpdated)
+    val blockId = blockUpdated.blockUpdatedInfo.blockId
+    val storageLevel = blockUpdated.blockUpdatedInfo.storageLevel
+    val memSize = blockUpdated.blockUpdatedInfo.memSize
+    val diskSize = blockUpdated.blockUpdatedInfo.diskSize
+    val blockStatus = BlockStatus(storageLevel, memSize, diskSize)
+    updateRDDInfo(Seq((blockId, blockStatus)))
+  }
 }
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
index 4a88eeee74..efc865919b 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json
@@ -2,8 +2,8 @@
   "id" : "",
   "hostPort" : "localhost:57971",
   "isActive" : true,
-  "rddBlocks" : 8,
-  "memoryUsed" : 28000128,
+  "rddBlocks" : 0,
+  "memoryUsed" : 0,
   "diskUsed" : 0,
   "totalCores" : 0,
   "maxTasks" : 0,
diff --git a/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json
index f79a31022d..8878e547a7 100644
--- a/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json
@@ -1,9 +1 @@
-[ {
-  "id" : 0,
-  "name" : "0",
-  "numPartitions" : 8,
-  "numCachedPartitions" : 8,
-  "storageLevel" : "Memory Deserialized 1x Replicated",
-  "memoryUsed" : 28000128,
-  "diskUsed" : 0
-} ]
\ No newline at end of file
+[ ]
\ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index 5822261d8d..79e4efb1a8 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -140,8 +140,9 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
     "stage task list from multi-attempt app json(2)" ->
       "applications/local-1426533911241/2/stages/0/0/taskList",
 
-    "rdd list storage json" -> "applications/local-1422981780767/storage/rdd",
-    "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0"
+    "rdd list storage json" -> "applications/local-1422981780767/storage/rdd"
+    // TODO: enable this test when the onBlockUpdated event is logged. See: SPARK-13845
+    // "one rdd storage json" -> "applications/local-1422981780767/storage/rdd/0"
   )
 
   // run a bunch of characterization tests -- just verify the behavior is the same as what is saved
diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
index 14daa003bc..9835f11a2f 100644
--- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala
@@ -82,48 +82,51 @@ class StorageStatusListenerSuite extends SparkFunSuite {
     assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
   }
 
-  test("task end with updated blocks") {
+  test("updated blocks") {
     val listener = new StorageStatusListener(conf)
     listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
     listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L))
-    val taskMetrics1 = new TaskMetrics
-    val taskMetrics2 = new TaskMetrics
-    val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L))
-    val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L))
-    val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L))
-    taskMetrics1.setUpdatedBlockStatuses(Seq(block1, block2))
-    taskMetrics2.setUpdatedBlockStatuses(Seq(block3))
-
-    // Task end with new blocks
+
+    val blockUpdateInfos1 = Seq(
+      BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L),
+      BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L)
+    )
+    val blockUpdateInfos2 =
+      Seq(BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L))
+
+    // Add some new blocks
     assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
     assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
-    listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
+    postUpdateBlock(listener, blockUpdateInfos1)
     assert(listener.executorIdToStorageStatus("big").numBlocks === 2)
     assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
     assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
     assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2)))
     assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
-    listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2))
+    postUpdateBlock(listener, blockUpdateInfos2)
     assert(listener.executorIdToStorageStatus("big").numBlocks === 2)
     assert(listener.executorIdToStorageStatus("fat").numBlocks === 1)
     assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
     assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2)))
     assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0)))
 
-    // Task end with dropped blocks
-    val droppedBlock1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.NONE, 0L, 0L))
-    val droppedBlock2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.NONE, 0L, 0L))
-    val droppedBlock3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.NONE, 0L, 0L))
-    taskMetrics1.setUpdatedBlockStatuses(Seq(droppedBlock1, droppedBlock3))
-    taskMetrics2.setUpdatedBlockStatuses(Seq(droppedBlock2, droppedBlock3))
+    // Dropped the blocks
+    val droppedBlockInfo1 = Seq(
+      BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.NONE, 0L, 0L),
+      BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L)
+    )
+    val droppedBlockInfo2 = Seq(
+      BlockUpdatedInfo(bm2, RDDBlockId(1, 2), StorageLevel.NONE, 0L, 0L),
+      BlockUpdatedInfo(bm2, RDDBlockId(4, 0), StorageLevel.NONE, 0L, 0L)
+    )
 
-    listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
+    postUpdateBlock(listener, droppedBlockInfo1)
     assert(listener.executorIdToStorageStatus("big").numBlocks === 1)
     assert(listener.executorIdToStorageStatus("fat").numBlocks === 1)
     assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
     assert(listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 2)))
     assert(listener.executorIdToStorageStatus("fat").containsBlock(RDDBlockId(4, 0)))
-    listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo2, taskMetrics2))
+    postUpdateBlock(listener, droppedBlockInfo2)
     assert(listener.executorIdToStorageStatus("big").numBlocks === 1)
     assert(listener.executorIdToStorageStatus("fat").numBlocks === 0)
     assert(!listener.executorIdToStorageStatus("big").containsBlock(RDDBlockId(1, 1)))
@@ -134,15 +137,14 @@ class StorageStatusListenerSuite extends SparkFunSuite {
   test("unpersist RDD") {
     val listener = new StorageStatusListener(conf)
     listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
-    val taskMetrics1 = new TaskMetrics
-    val taskMetrics2 = new TaskMetrics
-    val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L))
-    val block2 = (RDDBlockId(1, 2), BlockStatus(StorageLevel.DISK_ONLY, 0L, 200L))
-    val block3 = (RDDBlockId(4, 0), BlockStatus(StorageLevel.DISK_ONLY, 0L, 300L))
-    taskMetrics1.setUpdatedBlockStatuses(Seq(block1, block2))
-    taskMetrics2.setUpdatedBlockStatuses(Seq(block3))
-    listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics1))
-    listener.onTaskEnd(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo1, taskMetrics2))
+    val blockUpdateInfos1 = Seq(
+      BlockUpdatedInfo(bm1, RDDBlockId(1, 1), StorageLevel.DISK_ONLY, 0L, 100L),
+      BlockUpdatedInfo(bm1, RDDBlockId(1, 2), StorageLevel.DISK_ONLY, 0L, 200L)
+    )
+    val blockUpdateInfos2 =
+      Seq(BlockUpdatedInfo(bm1, RDDBlockId(4, 0), StorageLevel.DISK_ONLY, 0L, 300L))
+    postUpdateBlock(listener, blockUpdateInfos1)
+    postUpdateBlock(listener, blockUpdateInfos2)
     assert(listener.executorIdToStorageStatus("big").numBlocks === 3)
 
     // Unpersist RDD
@@ -155,4 +157,11 @@ class StorageStatusListenerSuite extends SparkFunSuite {
     listener.onUnpersistRDD(SparkListenerUnpersistRDD(1))
     assert(listener.executorIdToStorageStatus("big").numBlocks === 0)
   }
+
+  private def postUpdateBlock(
+      listener: StorageStatusListener, updateBlockInfos: Seq[BlockUpdatedInfo]): Unit = {
+    updateBlockInfos.foreach { updateBlockInfo =>
+      listener.onBlockUpdated(SparkListenerBlockUpdated(updateBlockInfo))
+    }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
index 6b7c538ac8..7d77deeb60 100644
--- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala
@@ -106,7 +106,7 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
     assert(storageListener.rddInfoList.size === 0)
   }
 
-  test("task end") {
+  test("block update") {
     val myRddInfo0 = rddInfo0
     val myRddInfo1 = rddInfo1
     val myRddInfo2 = rddInfo2
@@ -120,19 +120,13 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
     assert(!storageListener._rddInfoMap(1).isCached)
     assert(!storageListener._rddInfoMap(2).isCached)
 
-    // Task end with no updated blocks. This should not change anything.
-    bus.postToAll(SparkListenerTaskEnd(0, 0, "obliteration", Success, taskInfo, new TaskMetrics))
-    assert(storageListener._rddInfoMap.size === 3)
-    assert(storageListener.rddInfoList.size === 0)
-
-    // Task end with a few new persisted blocks, some from the same RDD
-    val metrics1 = new TaskMetrics
-    metrics1.setUpdatedBlockStatuses(Seq(
-      (RDDBlockId(0, 100), BlockStatus(memAndDisk, 400L, 0L)),
-      (RDDBlockId(0, 101), BlockStatus(memAndDisk, 0L, 400L)),
-      (RDDBlockId(1, 20), BlockStatus(memAndDisk, 0L, 240L))
-    ))
-    bus.postToAll(SparkListenerTaskEnd(1, 0, "obliteration", Success, taskInfo, metrics1))
+    // Some blocks updated
+    val blockUpdateInfos = Seq(
+      BlockUpdatedInfo(bm1, RDDBlockId(0, 100), memAndDisk, 400L, 0L),
+      BlockUpdatedInfo(bm1, RDDBlockId(0, 101), memAndDisk, 0L, 400L),
+      BlockUpdatedInfo(bm1, RDDBlockId(1, 20), memAndDisk, 0L, 240L)
+    )
+    postUpdateBlocks(bus, blockUpdateInfos)
     assert(storageListener._rddInfoMap(0).memSize === 400L)
     assert(storageListener._rddInfoMap(0).diskSize === 400L)
     assert(storageListener._rddInfoMap(0).numCachedPartitions === 2)
@@ -144,15 +138,14 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
     assert(!storageListener._rddInfoMap(2).isCached)
     assert(storageListener._rddInfoMap(2).numCachedPartitions === 0)
 
-    // Task end with a few dropped blocks
-    val metrics2 = new TaskMetrics
-    metrics2.setUpdatedBlockStatuses(Seq(
-      (RDDBlockId(0, 100), BlockStatus(none, 0L, 0L)),
-      (RDDBlockId(1, 20), BlockStatus(none, 0L, 0L)),
-      (RDDBlockId(2, 40), BlockStatus(none, 0L, 0L)), // doesn't actually exist
-      (RDDBlockId(4, 80), BlockStatus(none, 0L, 0L)) // doesn't actually exist
-    ))
-    bus.postToAll(SparkListenerTaskEnd(2, 0, "obliteration", Success, taskInfo, metrics2))
+    // Drop some blocks
+    val blockUpdateInfos2 = Seq(
+      BlockUpdatedInfo(bm1, RDDBlockId(0, 100), none, 0L, 0L),
+      BlockUpdatedInfo(bm1, RDDBlockId(1, 20), none, 0L, 0L),
+      BlockUpdatedInfo(bm1, RDDBlockId(2, 40), none, 0L, 0L), // doesn't actually exist
+      BlockUpdatedInfo(bm1, RDDBlockId(4, 80), none, 0L, 0L) // doesn't actually exist
+    )
+    postUpdateBlocks(bus, blockUpdateInfos2)
     assert(storageListener._rddInfoMap(0).memSize === 0L)
     assert(storageListener._rddInfoMap(0).diskSize === 400L)
     assert(storageListener._rddInfoMap(0).numCachedPartitions === 1)
@@ -169,24 +162,27 @@ class StorageTabSuite extends SparkFunSuite with BeforeAndAfter {
     val rddInfo1 = new RDDInfo(1, "rdd1", 1, memOnly, Seq(4))
     val stageInfo0 = new StageInfo(0, 0, "stage0", 1, Seq(rddInfo0), Seq.empty, "details")
     val stageInfo1 = new StageInfo(1, 0, "stage1", 1, Seq(rddInfo1), Seq.empty, "details")
-    val taskMetrics0 = new TaskMetrics
-    val taskMetrics1 = new TaskMetrics
-    val block0 = (RDDBlockId(0, 1), BlockStatus(memOnly, 100L, 0L))
-    val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L))
-    taskMetrics0.setUpdatedBlockStatuses(Seq(block0))
-    taskMetrics1.setUpdatedBlockStatuses(Seq(block1))
+    val blockUpdateInfos1 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(0, 1), memOnly, 100L, 0L))
+    val blockUpdateInfos2 = Seq(BlockUpdatedInfo(bm1, RDDBlockId(1, 1), memOnly, 200L, 0L))
     bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L))
     bus.postToAll(SparkListenerStageSubmitted(stageInfo0))
     assert(storageListener.rddInfoList.size === 0)
-    bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0))
+    postUpdateBlocks(bus, blockUpdateInfos1)
     assert(storageListener.rddInfoList.size === 1)
     bus.postToAll(SparkListenerStageSubmitted(stageInfo1))
     assert(storageListener.rddInfoList.size === 1)
     bus.postToAll(SparkListenerStageCompleted(stageInfo0))
     assert(storageListener.rddInfoList.size === 1)
-    bus.postToAll(SparkListenerTaskEnd(1, 0, "small", Success, taskInfo1, taskMetrics1))
+    postUpdateBlocks(bus, blockUpdateInfos2)
     assert(storageListener.rddInfoList.size === 2)
     bus.postToAll(SparkListenerStageCompleted(stageInfo1))
     assert(storageListener.rddInfoList.size === 2)
   }
+
+  private def postUpdateBlocks(
+      bus: SparkListenerBus, blockUpdateInfos: Seq[BlockUpdatedInfo]): Unit = {
+    blockUpdateInfos.foreach { blockUpdateInfo =>
+      bus.postToAll(SparkListenerBlockUpdated(blockUpdateInfo))
+    }
+  }
 }
-- 
cgit v1.2.3


From 2bc7c96d61a51bd458ba04e9d318640ddada559d Mon Sep 17 00:00:00 2001
From: jerryshao 
Date: Mon, 28 Mar 2016 17:03:21 -0700
Subject: [SPARK-13447][YARN][CORE] Clean the stale states for AM failure and
 restart situation

## What changes were proposed in this pull request?

This is a follow-up fix to #9963. In #9963 we handled this stale-state clean-up only for the dynamic-allocation-enabled scenario; here we should also clean the states in `CoarseGrainedSchedulerBackend` for the dynamic-allocation-disabled scenario.
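
Conceptually, the dynamic-allocation guard around the clean-up goes away; a minimal before/after sketch, where `clearStaleState` is a hypothetical stand-in for resetting the pending-executor counters and removing lingering executors (the real code is in the diff below):

```scala
// Sketch only; `clearStaleState` is hypothetical shorthand for the real clean-up.
def resetBefore(dynamicAllocationEnabled: Boolean, clearStaleState: () => Unit): Unit =
  if (dynamicAllocationEnabled) clearStaleState()  // skipped when disabled, leaving stale state behind

def resetAfter(clearStaleState: () => Unit): Unit =
  clearStaleState()                                // always runs when the AM re-registers
```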

Please review, CC andrewor14 lianhuiwang, thanks a lot.

## How was this patch tested?

Run the unit test locally, also with integration test manually.

Author: jerryshao 

Closes #11366 from jerryshao/SPARK-13447.
---
 .../cluster/CoarseGrainedSchedulerBackend.scala     | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index b7919efc4b..eb4f5331d6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -356,20 +356,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
 
   /**
    * Reset the state of CoarseGrainedSchedulerBackend to the initial state. Currently it will only
-   * be called in the yarn-client mode when AM re-registers after a failure, also dynamic
-   * allocation is enabled.
+   * be called in the yarn-client mode when AM re-registers after a failure.
    * */
   protected def reset(): Unit = synchronized {
-    if (Utils.isDynamicAllocationEnabled(conf)) {
-      numPendingExecutors = 0
-      executorsPendingToRemove.clear()
-
-      // Remove all the lingering executors that should be removed but not yet. The reason might be
-      // because (1) disconnected event is not yet received; (2) executors die silently.
-      executorDataMap.toMap.foreach { case (eid, _) =>
-        driverEndpoint.askWithRetry[Boolean](
-          RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered.")))
-      }
+    numPendingExecutors = 0
+    executorsPendingToRemove.clear()
+
+    // Remove all the lingering executors that should be removed but not yet. The reason might be
+    // because (1) disconnected event is not yet received; (2) executors die silently.
+    executorDataMap.toMap.foreach { case (eid, _) =>
+      driverEndpoint.askWithRetry[Boolean](
+        RemoveExecutor(eid, SlaveLost("Stale executor after cluster manager re-registered.")))
     }
   }
 
-- 
cgit v1.2.3


From 289257c4c6005779e416b23e593c61e6531b2d9a Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun 
Date: Mon, 28 Mar 2016 17:38:45 -0700
Subject: [SPARK-14219][GRAPHX] Fix `pickRandomVertex` not to fall into
 infinite loops for graphs with one vertex

## What changes were proposed in this pull request?

Currently, `GraphOps.pickRandomVertex()` falls into an infinite loop for graphs that have only one vertex. This PR fixes it by modifying the following termination-checking condition.
```scala
-      if (selectedVertices.count > 1) {
+      if (selectedVertices.count > 0) {
```
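
For illustration, a minimal GraphX snippet (assuming an active `SparkContext` named `sc`, e.g. inside `withSpark { sc => ... }`) exercising the single-vertex case that the old `count > 1` condition could never satisfy:

```scala
import org.apache.spark.graphx.{Edge, Graph}

// A graph with a single vertex and a self-edge: at most one vertex can ever be
// sampled, so `selectedVertices.count > 1` never held and the loop spun forever.
val vertices = sc.parallelize(Seq((1L, "a")), 1)
val edges = sc.parallelize(Seq(Edge[Long](1L, 1L)), 1)
val graph = Graph(vertices, edges)
assert(graph.pickRandomVertex() == 1L)  // terminates promptly with `count > 0`
```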

## How was this patch tested?

Pass the Jenkins tests (including the new test case).

Author: Dongjoon Hyun 

Closes #12018 from dongjoon-hyun/SPARK-14219.
---
 graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala   | 2 +-
 graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala | 9 +++++++++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
index fcb1b5999f..a783fe305f 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala
@@ -276,7 +276,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali
         if (Random.nextDouble() < probability) { Some(vidVvals._1) }
         else { None }
       }
-      if (selectedVertices.count > 1) {
+      if (selectedVertices.count > 0) {
         found = true
         val collectedVertices = selectedVertices.collect()
         retVal = collectedVertices(Random.nextInt(collectedVertices.length))
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index cb981797d3..96aa262a39 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -404,4 +404,13 @@ class GraphSuite extends SparkFunSuite with LocalSparkContext {
       assert(sc.getPersistentRDDs.isEmpty)
     }
   }
+
+  test("SPARK-14219: pickRandomVertex") {
+    withSpark { sc =>
+      val vert = sc.parallelize(List((1L, "a")), 1)
+      val edges = sc.parallelize(List(Edge[Long](1L, 1L)), 1)
+      val g0 = Graph(vert, edges)
+      assert(g0.pickRandomVertex() === 1L)
+    }
+  }
 }
-- 
cgit v1.2.3


From 38326cad873017ca07e90bc4472d01a42589d4cb Mon Sep 17 00:00:00 2001
From: Wenchen Fan 
Date: Mon, 28 Mar 2016 18:53:47 -0700
Subject: [SPARK-14205][SQL] remove trait Queryable

## What changes were proposed in this pull request?

Now that DataFrame and Dataset have been merged, the trait `Queryable` is unnecessary, as it has only one implementation. We should remove it.
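
Caller-visible behavior is unchanged; a small sketch (assuming an existing `SQLContext` named `sqlContext`) of the methods that now live directly on `Dataset` instead of being inherited from `Queryable`:

```scala
// These were previously declared on the Queryable trait; after this change they
// are plain Dataset members with the same signatures.
val ds = sqlContext.range(3)
println(ds.toString)   // e.g. "[id: bigint]"
ds.printSchema()
ds.explain()
```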

## How was this patch tested?

Existing tests.

Author: Wenchen Fan 

Closes #12001 from cloud-fan/df-ds.
---
 project/MimaExcludes.scala                         |   3 +
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  88 +++++++++++++--
 .../apache/spark/sql/KeyValueGroupedDataset.scala  |  13 ---
 .../scala/org/apache/spark/sql/SQLContext.scala    |   4 +-
 .../apache/spark/sql/execution/CacheManager.scala  |  17 +--
 .../org/apache/spark/sql/execution/Queryable.scala | 124 ---------------------
 .../scala/org/apache/spark/sql/QueryTest.scala     |  10 +-
 7 files changed, 98 insertions(+), 161 deletions(-)
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala

diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 208c7a28cf..94621d7fa3 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -589,6 +589,9 @@ object MimaExcludes {
         ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"),
         ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"),
         ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol")
+      ) ++ Seq(
+        // [SPARK-14205][SQL] remove trait Queryable
+        ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset")
       )
     case v if v.startsWith("1.6") =>
       Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 703ea4d149..41cb799b97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -22,8 +22,10 @@ import java.io.CharArrayWriter
 import scala.collection.JavaConverters._
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
 
 import com.fasterxml.jackson.core.JsonFactory
+import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
@@ -39,7 +41,7 @@ import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.util.usePrettyExpression
-import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution}
+import org.apache.spark.sql.execution.{FileRelation, LogicalRDD, QueryExecution, SQLExecution}
 import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
@@ -150,10 +152,10 @@ private[sql] object Dataset {
  * @since 1.6.0
  */
 class Dataset[T] private[sql](
-    @transient override val sqlContext: SQLContext,
-    @DeveloperApi @transient override val queryExecution: QueryExecution,
+    @transient val sqlContext: SQLContext,
+    @DeveloperApi @transient val queryExecution: QueryExecution,
     encoder: Encoder[T])
-  extends Queryable with Serializable {
+  extends Serializable {
 
   queryExecution.assertAnalyzed()
 
@@ -224,7 +226,7 @@ class Dataset[T] private[sql](
    * @param _numRows Number of rows to show
    * @param truncate Whether truncate long strings and align cells right
    */
-  override private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
+  private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = {
     val numRows = _numRows.max(0)
     val takeResult = take(numRows + 1)
     val hasMoreData = takeResult.length > numRows
@@ -249,7 +251,75 @@ class Dataset[T] private[sql](
       }: Seq[String]
     }
 
-    formatString ( rows, numRows, hasMoreData, truncate )
+    val sb = new StringBuilder
+    val numCols = schema.fieldNames.length
+
+    // Initialise the width of each column to a minimum value of '3'
+    val colWidths = Array.fill(numCols)(3)
+
+    // Compute the width of each column
+    for (row <- rows) {
+      for ((cell, i) <- row.zipWithIndex) {
+        colWidths(i) = math.max(colWidths(i), cell.length)
+      }
+    }
+
+    // Create SeparateLine
+    val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
+
+    // column names
+    rows.head.zipWithIndex.map { case (cell, i) =>
+      if (truncate) {
+        StringUtils.leftPad(cell, colWidths(i))
+      } else {
+        StringUtils.rightPad(cell, colWidths(i))
+      }
+    }.addString(sb, "|", "|", "|\n")
+
+    sb.append(sep)
+
+    // data
+    rows.tail.map {
+      _.zipWithIndex.map { case (cell, i) =>
+        if (truncate) {
+          StringUtils.leftPad(cell.toString, colWidths(i))
+        } else {
+          StringUtils.rightPad(cell.toString, colWidths(i))
+        }
+      }.addString(sb, "|", "|", "|\n")
+    }
+
+    sb.append(sep)
+
+    // For Data that has more than "numRows" records
+    if (hasMoreData) {
+      val rowsString = if (numRows == 1) "row" else "rows"
+      sb.append(s"only showing top $numRows $rowsString\n")
+    }
+
+    sb.toString()
+  }
+
+  override def toString: String = {
+    try {
+      val builder = new StringBuilder
+      val fields = schema.take(2).map {
+        case f => s"${f.name}: ${f.dataType.simpleString(2)}"
+      }
+      builder.append("[")
+      builder.append(fields.mkString(", "))
+      if (schema.length > 2) {
+        if (schema.length - fields.size == 1) {
+          builder.append(" ... 1 more field")
+        } else {
+          builder.append(" ... " + (schema.length - 2) + " more fields")
+        }
+      }
+      builder.append("]").toString()
+    } catch {
+      case NonFatal(e) =>
+        s"Invalid tree; ${e.getMessage}:\n$queryExecution"
+    }
   }
 
   /**
@@ -325,7 +395,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   // scalastyle:off println
-  override def printSchema(): Unit = println(schema.treeString)
+  def printSchema(): Unit = println(schema.treeString)
   // scalastyle:on println
 
   /**
@@ -334,7 +404,7 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  override def explain(extended: Boolean): Unit = {
+  def explain(extended: Boolean): Unit = {
     val explain = ExplainCommand(queryExecution.logical, extended = extended)
     sqlContext.executePlan(explain).executedPlan.executeCollect().foreach {
       // scalastyle:off println
@@ -349,7 +419,7 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 1.6.0
    */
-  override def explain(): Unit = explain(extended = false)
+  def explain(): Unit = explain(extended = false)
 
   /**
    * Returns all column names and their data types as an array.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 07aa1515f3..f19ad6e707 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -57,13 +57,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext
 
-  private def groupedData = {
-    new RelationalGroupedDataset(
-      Dataset.ofRows(sqlContext, logicalPlan),
-      groupingAttributes,
-      RelationalGroupedDataset.GroupByType)
-  }
-
   /**
    * Returns a new [[KeyValueGroupedDataset]] where the type of the key has been mapped to the
    * specified type. The mapping of key columns to the type follows the same rules as `as` on
@@ -207,12 +200,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
     reduceGroups(f.call _)
   }
 
-  private def withEncoder(c: Column): Column = c match {
-    case tc: TypedColumn[_, _] =>
-      tc.withInputType(resolvedVEncoder.bind(dataAttributes), dataAttributes)
-    case _ => c
-  }
-
   /**
    * Internal helper function for building typed aggregations that return tuples.  For simplicity
    * and code reuse, we do this without the help of the type system and then use helper functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index c94600925f..0576a1a178 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -272,11 +272,11 @@ class SQLContext private[sql](
   }
 
   /**
-   * Returns true if the [[Queryable]] is currently cached in-memory.
+   * Returns true if the [[Dataset]] is currently cached in-memory.
    * @group cachemgmt
    * @since 1.3.0
    */
-  private[sql] def isCached(qName: Queryable): Boolean = {
+  private[sql] def isCached(qName: Dataset[_]): Boolean = {
     cacheManager.lookupCachedData(qName).nonEmpty
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 14b8b6fc3b..f3478a873a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.Dataset
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
 
@@ -74,12 +75,12 @@ private[sql] class CacheManager extends Logging {
   }
 
   /**
-   * Caches the data produced by the logical representation of the given [[Queryable]].
+   * Caches the data produced by the logical representation of the given [[Dataset]].
    * Unlike `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because
    * recomputing the in-memory columnar representation of the underlying table is expensive.
    */
   private[sql] def cacheQuery(
-      query: Queryable,
+      query: Dataset[_],
       tableName: Option[String] = None,
       storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
@@ -99,8 +100,8 @@ private[sql] class CacheManager extends Logging {
     }
   }
 
-  /** Removes the data for the given [[Queryable]] from the cache */
-  private[sql] def uncacheQuery(query: Queryable, blocking: Boolean = true): Unit = writeLock {
+  /** Removes the data for the given [[Dataset]] from the cache */
+  private[sql] def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
     require(dataIndex >= 0, s"Table $query is not cached.")
@@ -108,11 +109,11 @@ private[sql] class CacheManager extends Logging {
     cachedData.remove(dataIndex)
   }
 
-  /** Tries to remove the data for the given [[Queryable]] from the cache
+  /** Tries to remove the data for the given [[Dataset]] from the cache
     * if it's cached
     */
   private[sql] def tryUncacheQuery(
-      query: Queryable,
+      query: Dataset[_],
       blocking: Boolean = true): Boolean = writeLock {
     val planToCache = query.queryExecution.analyzed
     val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan))
@@ -124,8 +125,8 @@ private[sql] class CacheManager extends Logging {
     found
   }
 
-  /** Optionally returns cached data for the given [[Queryable]] */
-  private[sql] def lookupCachedData(query: Queryable): Option[CachedData] = readLock {
+  /** Optionally returns cached data for the given [[Dataset]] */
+  private[sql] def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock {
     lookupCachedData(query.queryExecution.analyzed)
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala
deleted file mode 100644
index 38263af0f7..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution
-
-import scala.util.control.NonFatal
-
-import org.apache.commons.lang3.StringUtils
-
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.types.StructType
-
-/** A trait that holds shared code between DataFrames and Datasets. */
-private[sql] trait Queryable {
-  def schema: StructType
-  def queryExecution: QueryExecution
-  def sqlContext: SQLContext
-
-  override def toString: String = {
-    try {
-      val builder = new StringBuilder
-      val fields = schema.take(2).map {
-        case f => s"${f.name}: ${f.dataType.simpleString(2)}"
-      }
-      builder.append("[")
-      builder.append(fields.mkString(", "))
-      if (schema.length > 2) {
-        if (schema.length - fields.size == 1) {
-          builder.append(" ... 1 more field")
-        } else {
-          builder.append(" ... " + (schema.length - 2) + " more fields")
-        }
-      }
-      builder.append("]").toString()
-    } catch {
-      case NonFatal(e) =>
-        s"Invalid tree; ${e.getMessage}:\n$queryExecution"
-    }
-  }
-
-  def printSchema(): Unit
-
-  def explain(extended: Boolean): Unit
-
-  def explain(): Unit
-
-  private[sql] def showString(_numRows: Int, truncate: Boolean = true): String
-
-  /**
-   * Format the string representing rows for output
-   * @param rows The rows to show
-   * @param numRows Number of rows to show
-   * @param hasMoreData Whether some rows are not shown due to the limit
-   * @param truncate Whether truncate long strings and align cells right
-   *
-   */
-  private[sql] def formatString (
-      rows: Seq[Seq[String]],
-      numRows: Int,
-      hasMoreData : Boolean,
-      truncate: Boolean = true): String = {
-    val sb = new StringBuilder
-    val numCols = schema.fieldNames.length
-
-    // Initialise the width of each column to a minimum value of '3'
-    val colWidths = Array.fill(numCols)(3)
-
-    // Compute the width of each column
-    for (row <- rows) {
-      for ((cell, i) <- row.zipWithIndex) {
-        colWidths(i) = math.max(colWidths(i), cell.length)
-      }
-    }
-
-    // Create SeparateLine
-    val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
-
-    // column names
-    rows.head.zipWithIndex.map { case (cell, i) =>
-      if (truncate) {
-        StringUtils.leftPad(cell, colWidths(i))
-      } else {
-        StringUtils.rightPad(cell, colWidths(i))
-      }
-    }.addString(sb, "|", "|", "|\n")
-
-    sb.append(sep)
-
-    // data
-    rows.tail.map {
-      _.zipWithIndex.map { case (cell, i) =>
-        if (truncate) {
-          StringUtils.leftPad(cell.toString, colWidths(i))
-        } else {
-          StringUtils.rightPad(cell.toString, colWidths(i))
-        }
-      }.addString(sb, "|", "|", "|\n")
-    }
-
-    sb.append(sep)
-
-    // For Data that has more than "numRows" records
-    if (hasMoreData) {
-      val rowsString = if (numRows == 1) "row" else "rows"
-      sb.append(s"only showing top $numRows $rowsString\n")
-    }
-
-    sb.toString()
-  }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index a1b45ca7eb..7ff4ffcaec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.{LogicalRDD, Queryable}
+import org.apache.spark.sql.execution.LogicalRDD
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 
@@ -180,9 +180,9 @@ abstract class QueryTest extends PlanTest {
   }
 
   /**
-   * Asserts that a given [[Queryable]] will be executed using the given number of cached results.
+   * Asserts that a given [[Dataset]] will be executed using the given number of cached results.
    */
-  def assertCached(query: Queryable, numCachedTables: Int = 1): Unit = {
+  def assertCached(query: Dataset[_], numCachedTables: Int = 1): Unit = {
     val planWithCaching = query.queryExecution.withCachedData
     val cachedData = planWithCaching collect {
       case cached: InMemoryRelation => cached
@@ -286,9 +286,9 @@ abstract class QueryTest extends PlanTest {
   }
 
   /**
-    * Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans.
+    * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans.
     */
-  def assertEmptyMissingInput(query: Queryable): Unit = {
+  def assertEmptyMissingInput(query: Dataset[_]): Unit = {
     assert(query.queryExecution.analyzed.missingInput.isEmpty,
       s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}")
     assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
-- 
cgit v1.2.3


From 27d4ef0c619501be592366ae8f0be77294c9687d Mon Sep 17 00:00:00 2001
From: Herman van Hovell 
Date: Mon, 28 Mar 2016 20:19:21 -0700
Subject: [SPARK-14213][SQL] Migrate HiveQl parsing to ANTLR4 parser

### What changes were proposed in this pull request?

This PR migrates all HiveQl parsing to the new ANTLR4 parser. It is built on top of https://github.com/apache/spark/pull/12011, and we should wait to merge until that one is in (hence the WIP tag).

As soon as this PR is merged we can start removing much of the old parser infrastructure.

### How was this patch tested?

Existing Hive unit tests.

cc rxin andrewor14 yhuai

Author: Herman van Hovell 

Closes #12015 from hvanhovell/SPARK-14213.
---
 .../apache/spark/sql/catalyst/parser/ng/SqlBase.g4 |  34 +-
 .../spark/sql/hive/HiveMetastoreCatalog.scala      |  13 +-
 .../apache/spark/sql/hive/HiveSessionState.scala   |   3 +-
 .../spark/sql/hive/execution/HiveSqlParser.scala   | 442 +++++++++++++++++++++
 4 files changed, 488 insertions(+), 4 deletions(-)
 create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala

diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4
index e46fd9bed5..4e77b6db25 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4
@@ -118,7 +118,9 @@ statement
     | UNCACHE TABLE identifier                                         #uncacheTable
     | CLEAR CACHE                                                      #clearCache
     | ADD identifier .*?                                               #addResource
+    | SET ROLE .*?                                                     #failNativeCommand
     | SET .*?                                                          #setConfiguration
+    | kws=unsupportedHiveNativeCommands .*?                            #failNativeCommand
     | hiveNativeCommands                                               #executeNativeCommand
     ;
 
@@ -145,7 +147,26 @@ hiveNativeCommands
     | ROLLBACK WORK?
     | SHOW PARTITIONS tableIdentifier partitionSpec?
     | DFS .*?
-    | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD) .*?
+    | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOCK | UNLOCK | MSCK | LOAD) .*?
+    ;
+
+unsupportedHiveNativeCommands
+    : kw1=CREATE kw2=ROLE
+    | kw1=DROP kw2=ROLE
+    | kw1=GRANT kw2=ROLE?
+    | kw1=REVOKE kw2=ROLE?
+    | kw1=SHOW kw2=GRANT
+    | kw1=SHOW kw2=ROLE kw3=GRANT?
+    | kw1=SHOW kw2=PRINCIPALS
+    | kw1=SHOW kw2=ROLES
+    | kw1=SHOW kw2=CURRENT kw3=ROLES
+    | kw1=EXPORT kw2=TABLE
+    | kw1=IMPORT kw2=TABLE
+    | kw1=SHOW kw2=COMPACTIONS
+    | kw1=SHOW kw2=CREATE kw3=TABLE
+    | kw1=SHOW kw2=TRANSACTIONS
+    | kw1=SHOW kw2=INDEXES
+    | kw1=SHOW kw2=LOCKS
     ;
 
 createTableHeader
@@ -619,7 +640,8 @@ nonReserved
     | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT
     | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE
     | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER
-    | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT
+    | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE
+    | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEXES | LOCKS | OPTION
     ;
 
 SELECT: 'SELECT';
@@ -834,6 +856,14 @@ MSCK: 'MSCK';
 EXPORT: 'EXPORT';
 IMPORT: 'IMPORT';
 LOAD: 'LOAD';
+ROLE: 'ROLE';
+ROLES: 'ROLES';
+COMPACTIONS: 'COMPACTIONS';
+PRINCIPALS: 'PRINCIPALS';
+TRANSACTIONS: 'TRANSACTIONS';
+INDEXES: 'INDEXES';
+LOCKS: 'LOCKS';
+OPTION: 'OPTION';
 
 STRING
     : '\'' ( ~('\''|'\\') | ('\\' .) )* '\''
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index eedd12d76a..9a5ec9880e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -85,7 +85,18 @@ private[hive] object HiveSerDe {
         HiveSerDe(
           inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"),
           outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"),
-          serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")))
+          serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")),
+
+      "textfile" ->
+        HiveSerDe(
+          inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"),
+          outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")),
+
+      "avro" ->
+        HiveSerDe(
+          inputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"),
+          outputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat"),
+          serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")))
 
     val key = source.toLowerCase match {
       case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet"
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index c9b6b1dfb6..11ef0fd1bb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry}
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.execution.{python, SparkPlanner}
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.hive.execution.HiveSqlParser
 import org.apache.spark.sql.internal.{SessionState, SQLConf}
 
 
@@ -70,7 +71,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
   /**
    * Parser for HiveQl query texts.
    */
-  override lazy val sqlParser: ParserInterface = new HiveQl(conf)
+  override lazy val sqlParser: ParserInterface = HiveSqlParser
 
   /**
    * Planner that takes into account Hive-specific strategies.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala
new file mode 100644
index 0000000000..d6a08fcc57
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala
@@ -0,0 +1,442 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive.execution
+
+import scala.collection.JavaConverters._
+
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
+import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hadoop.hive.ql.exec.FunctionRegistry
+import org.apache.hadoop.hive.ql.parse.EximUtil
+import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde.serdeConstants
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+
+import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.parser.ng._
+import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkSqlAstBuilder
+import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView}
+import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveSerDe}
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
+
+/**
+ * Concrete parser for HiveQl statements.
+ */
+object HiveSqlParser extends AbstractSqlParser {
+  val astBuilder = new HiveSqlAstBuilder
+
+  override protected def nativeCommand(sqlText: String): LogicalPlan = {
+    HiveNativeCommand(sqlText)
+  }
+}
+
+/**
+ * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
+ */
+class HiveSqlAstBuilder extends SparkSqlAstBuilder {
+  import ParserUtils._
+
+  /**
+   * Get the current Hive Configuration.
+   */
+  private[this] def hiveConf: HiveConf = {
+    var ss = SessionState.get()
+    // SessionState is lazy initialization, it can be null here
+    if (ss == null) {
+      val original = Thread.currentThread().getContextClassLoader
+      val conf = new HiveConf(classOf[SessionState])
+      conf.setClassLoader(original)
+      ss = new SessionState(conf)
+      SessionState.start(ss)
+    }
+    ss.getConf
+  }
+
+  /**
+   * Pass a command to Hive using a [[HiveNativeCommand]].
+   */
+  override def visitExecuteNativeCommand(
+      ctx: ExecuteNativeCommandContext): LogicalPlan = withOrigin(ctx) {
+    HiveNativeCommand(command(ctx))
+  }
+
+  /**
+   * Fail an unsupported Hive native command.
+   */
+  override def visitFailNativeCommand(
+      ctx: FailNativeCommandContext): LogicalPlan = withOrigin(ctx) {
+    val keywords = if (ctx.kws != null) {
+      Seq(ctx.kws.kw1, ctx.kws.kw2, ctx.kws.kw3).filter(_ != null).map(_.getText).mkString(" ")
+    } else {
+      // SET ROLE is the exception to the rule, because we handle this before other SET commands.
+      "SET ROLE"
+    }
+    throw new ParseException(s"Unsupported operation: $keywords", ctx)
+  }
+
+  /**
+   * Create an [[AddJar]] or [[AddFile]] command depending on the requested resource.
+   */
+  override def visitAddResource(ctx: AddResourceContext): LogicalPlan = withOrigin(ctx) {
+    ctx.identifier.getText.toLowerCase match {
+      case "file" => AddFile(remainder(ctx.identifier).trim)
+      case "jar" => AddJar(remainder(ctx.identifier).trim)
+      case other => throw new ParseException(s"Unsupported resource type '$other'.", ctx)
+    }
+  }
+
+  /**
+   * Create a [[DropTable]] command.
+   */
+  override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) {
+    if (ctx.PURGE != null) {
+      logWarning("PURGE option is ignored.")
+    }
+    if (ctx.REPLICATION != null) {
+      logWarning("REPLICATION clause is ignored.")
+    }
+    DropTable(visitTableIdentifier(ctx.tableIdentifier).toString, ctx.EXISTS != null)
+  }
+
+  /**
+   * Create an [[AnalyzeTable]] command. This currently only implements the NOSCAN option (other
+   * options are passed on to Hive) e.g.:
+   * {{{
+   *   ANALYZE TABLE table COMPUTE STATISTICS NOSCAN;
+   * }}}
+   */
+  override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) {
+    if (ctx.partitionSpec == null &&
+      ctx.identifier != null &&
+      ctx.identifier.getText.toLowerCase == "noscan") {
+      AnalyzeTable(visitTableIdentifier(ctx.tableIdentifier).toString)
+    } else {
+      HiveNativeCommand(command(ctx))
+    }
+  }
+
+  /**
+   * Create a [[CreateTableAsSelect]] command.
+   */
+  override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = {
+    if (ctx.query == null) {
+      HiveNativeCommand(command(ctx))
+    } else {
+      // Get the table header.
+      val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader)
+      val tableType = if (external) {
+        CatalogTableType.EXTERNAL_TABLE
+      } else {
+        CatalogTableType.MANAGED_TABLE
+      }
+
+      // Unsupported clauses.
+      if (temp) {
+        logWarning("TEMPORARY clause is ignored.")
+      }
+      if (ctx.bucketSpec != null) {
+        // TODO add this - we need cluster columns in the CatalogTable for this to work.
+        logWarning("CLUSTERED BY ... [ORDERED BY ...] INTO ... BUCKETS clause is ignored.")
+      }
+      if (ctx.skewSpec != null) {
+        logWarning("SKEWED BY ... ON ... [STORED AS DIRECTORIES] clause is ignored.")
+      }
+
+      // Create the schema.
+      val schema = Option(ctx.colTypeList).toSeq.flatMap(_.colType.asScala).map { col =>
+        CatalogColumn(
+          col.identifier.getText,
+          col.dataType.getText.toLowerCase, // TODO validate this?
+          nullable = true,
+          Option(col.STRING).map(string))
+      }
+
+      // Get the column by which the table is partitioned.
+      val partitionCols = Option(ctx.identifierList).toSeq.flatMap(visitIdentifierList).map {
+        CatalogColumn(_, null, nullable = true, None)
+      }
+
+      // Create the storage.
+      def format(fmt: ParserRuleContext): CatalogStorageFormat = {
+        Option(fmt).map(typedVisit[CatalogStorageFormat]).getOrElse(EmptyStorageFormat)
+      }
+      // Default storage.
+      val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT)
+      val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse {
+        HiveSerDe(
+          inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"),
+          outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat"))
+      }
+      // Defined storage.
+      val fileStorage = format(ctx.createFileFormat)
+      val rowStorage = format(ctx.rowFormat)
+      val storage = CatalogStorageFormat(
+        Option(ctx.locationSpec).map(visitLocationSpec),
+        fileStorage.inputFormat.orElse(hiveSerDe.inputFormat),
+        fileStorage.outputFormat.orElse(hiveSerDe.outputFormat),
+        rowStorage.serde.orElse(hiveSerDe.serde).orElse(fileStorage.serde),
+        rowStorage.serdeProperties ++ fileStorage.serdeProperties
+      )
+
+      val tableDesc = CatalogTable(
+        identifier = table,
+        tableType = tableType,
+        schema = schema,
+        partitionColumns = partitionCols,
+        storage = storage,
+        properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty),
+        // TODO support the sql text - have a proper location for this!
+        viewText = Option(ctx.STRING).map(string))
+      CTAS(tableDesc, plan(ctx.query), ifNotExists)
+    }
+  }
+
+  /**
+   * Create or replace a view. This creates a [[CreateViewAsSelect]] command.
+   */
+  override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) {
+    // Pass a partitioned view on to hive.
+    if (ctx.identifierList != null) {
+      HiveNativeCommand(command(ctx))
+    } else {
+      if (ctx.STRING != null) {
+        logWarning("COMMENT clause is ignored.")
+      }
+      val identifiers = Option(ctx.identifierCommentList).toSeq.flatMap(_.identifierComment.asScala)
+      val schema = identifiers.map { ic =>
+        CatalogColumn(ic.identifier.getText, null, nullable = true, Option(ic.STRING).map(string))
+      }
+      createView(
+        ctx,
+        ctx.tableIdentifier,
+        schema,
+        ctx.query,
+        Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty),
+        ctx.EXISTS != null,
+        ctx.REPLACE != null
+      )
+    }
+  }
+
+  /**
+   * Alter the query of a view. This creates a [[CreateViewAsSelect]] command.
+   */
+  override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) {
+    createView(
+      ctx,
+      ctx.tableIdentifier,
+      Seq.empty,
+      ctx.query,
+      Map.empty,
+      allowExist = false,
+      replace = true)
+  }
+
+  /**
+   * Create a [[CreateViewAsSelect]] command.
+   */
+  private def createView(
+      ctx: ParserRuleContext,
+      name: TableIdentifierContext,
+      schema: Seq[CatalogColumn],
+      query: QueryContext,
+      properties: Map[String, String],
+      allowExist: Boolean,
+      replace: Boolean): LogicalPlan = {
+    val sql = Option(source(query))
+    val tableDesc = CatalogTable(
+      identifier = visitTableIdentifier(name),
+      tableType = CatalogTableType.VIRTUAL_VIEW,
+      schema = schema,
+      storage = EmptyStorageFormat,
+      properties = properties,
+      viewOriginalText = sql,
+      viewText = sql)
+    CreateView(tableDesc, plan(query), allowExist, replace, command(ctx))
+  }
+
+  /**
+   * Create a [[Generator]]. Override this method in order to support custom Generators.
+   */
+  override protected def withGenerator(
+      name: String,
+      expressions: Seq[Expression],
+      ctx: LateralViewContext): Generator = {
+    val info = Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse {
+      throw new ParseException(s"Couldn't find Generator function '$name'", ctx)
+    }
+    HiveGenericUDTF(name, new HiveFunctionWrapper(info.getFunctionClass.getName), expressions)
+  }
+
+  /**
+   * Create a [[HiveScriptIOSchema]].
+   */
+  override protected def withScriptIOSchema(
+      inRowFormat: RowFormatContext,
+      recordWriter: Token,
+      outRowFormat: RowFormatContext,
+      recordReader: Token,
+      schemaLess: Boolean): HiveScriptIOSchema = {
+    if (recordWriter != null || recordReader != null) {
+      logWarning("Used defined record reader/writer classes are currently ignored.")
+    }
+
+    // Decode and input/output format.
+    type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String])
+    def format(fmt: RowFormatContext, confVar: ConfVars): Format = fmt match {
+      case c: RowFormatDelimitedContext =>
+        // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema
+        // expects a seq of pairs in which the old parsers' token names are used as keys.
+        // Transforming the result of visitRowFormatDelimited would be quite a bit messier than
+        // retrieving the key value pairs ourselves.
+        def entry(key: String, value: Token): Seq[(String, String)] = {
+          Option(value).map(t => key -> t.getText).toSeq
+        }
+        val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++
+          entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++
+          entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++
+          entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++
+          entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs)
+
+        (entries, None, Seq.empty, None)
+
+      case c: RowFormatSerdeContext =>
+        // Use a serde format.
+        val CatalogStorageFormat(None, None, None, Some(name), props) = visitRowFormatSerde(c)
+
+        // SPARK-10310: Special cases LazySimpleSerDe
+        val recordHandler = if (name == classOf[LazySimpleSerDe].getCanonicalName) {
+          Option(hiveConf.getVar(confVar))
+        } else {
+          None
+        }
+        (Seq.empty, Option(name), props.toSeq, recordHandler)
+
+      case null =>
+        // Use default (serde) format.
+        val name = hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)
+        val props = Seq(serdeConstants.FIELD_DELIM -> "\t")
+        val recordHandler = Option(hiveConf.getVar(confVar))
+        (Nil, Option(name), props, recordHandler)
+    }
+
+    val (inFormat, inSerdeClass, inSerdeProps, reader) =
+      format(inRowFormat, ConfVars.HIVESCRIPTRECORDREADER)
+
+    val (outFormat, outSerdeClass, outSerdeProps, writer) =
+      format(inRowFormat, ConfVars.HIVESCRIPTRECORDWRITER)
+
+    HiveScriptIOSchema(
+      inFormat, outFormat,
+      inSerdeClass, outSerdeClass,
+      inSerdeProps, outSerdeProps,
+      reader, writer,
+      schemaLess)
+  }
+
+  /**
+   * Create location string.
+   */
+  override def visitLocationSpec(ctx: LocationSpecContext): String = {
+    EximUtil.relativeToAbsolutePath(hiveConf, super.visitLocationSpec(ctx))
+  }
+
+  /** Empty storage format for default values and copies. */
+  private val EmptyStorageFormat = CatalogStorageFormat(None, None, None, None, Map.empty)
+
+  /**
+   * Create a [[CatalogStorageFormat]]. The INPUTDRIVER and OUTPUTDRIVER clauses are currently
+   * ignored.
+   */
+  override def visitTableFileFormat(
+      ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) {
+    import ctx._
+    if (inDriver != null || outDriver != null) {
+      logWarning("INPUTDRIVER ... OUTPUTDRIVER ... clauses are ignored.")
+    }
+    EmptyStorageFormat.copy(
+      inputFormat = Option(string(inFmt)),
+      outputFormat = Option(string(outFmt)),
+      serde = Option(serdeCls).map(string)
+    )
+  }
+
+  /**
+   * Resolve a [[HiveSerDe]] based on the format name given.
+   */
+  override def visitGenericFileFormat(
+      ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) {
+    val source = ctx.identifier.getText
+    HiveSerDe.sourceToSerDe(source, hiveConf) match {
+      case Some(s) =>
+        EmptyStorageFormat.copy(
+          inputFormat = s.inputFormat,
+          outputFormat = s.outputFormat,
+          serde = s.serde)
+      case None =>
+        throw new ParseException(s"Unrecognized file format in STORED AS clause: $source", ctx)
+    }
+  }
+
+  /**
+   * Storage Handlers are currently not supported in the statements we support (CTAS).
+   */
+  override def visitStorageHandler(ctx: StorageHandlerContext): AnyRef = withOrigin(ctx) {
+    throw new ParseException("Storage Handlers are currently unsupported.", ctx)
+  }
+
+  /**
+   * Create SERDE row format name and properties pair.
+   */
+  override def visitRowFormatSerde(
+      ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) {
+    import ctx._
+    EmptyStorageFormat.copy(
+      serde = Option(string(name)),
+      serdeProperties = Option(tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty))
+  }
+
+  /**
+   * Create a delimited row format properties object.
+   */
+  override def visitRowFormatDelimited(
+      ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) {
+    // Collect the entries if any.
+    def entry(key: String, value: Token): Seq[(String, String)] = {
+      Option(value).toSeq.map(x => key -> string(x))
+    }
+    // TODO we need proper support for the NULL format.
+    val entries = entry(serdeConstants.FIELD_DELIM, ctx.fieldsTerminatedBy) ++
+      entry(serdeConstants.SERIALIZATION_FORMAT, ctx.fieldsTerminatedBy) ++
+      entry(serdeConstants.ESCAPE_CHAR, ctx.escapedBy) ++
+      entry(serdeConstants.COLLECTION_DELIM, ctx.collectionItemsTerminatedBy) ++
+      entry(serdeConstants.MAPKEY_DELIM, ctx.keysTerminatedBy) ++
+      Option(ctx.linesSeparatedBy).toSeq.map { token =>
+        val value = string(token)
+        assert(
+          value == "\n",
+          s"LINES TERMINATED BY only supports newline '\\n' right now: $value",
+          ctx)
+        serdeConstants.LINE_DELIM -> value
+      }
+    EmptyStorageFormat.copy(serdeProperties = entries.toMap)
+  }
+}
-- 
cgit v1.2.3


From 4a55c336397d3f138c6f5735675ec7cb272827f5 Mon Sep 17 00:00:00 2001
From: Nong Li 
Date: Mon, 28 Mar 2016 20:32:58 -0700
Subject: [SPARK-13981][SQL] Defer evaluating variables within Filter operator.

## What changes were proposed in this pull request?

This improves the Filter codegen for NULLs by deferring loading the values for IsNotNull.
Instead of generating code like:

```
boolean isNull = ...
int value = ...
if (isNull) continue;
```

we will generate:

```
boolean isNull = ...
if (isNull) continue;
int value = ...
```

This is useful since retrieving the values can be non-trivial (they can be dictionary encoded,
among other things). This currently only works when the attribute comes from the column batch,
but could be extended to other cases in the future.
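
As a rough Scala analogy (the real change is in the generated Java produced by `doConsume`, shown in the diff below), the idea is to pay for loading a column value only after its null check has passed; `passesFilter`, `loadValue`, and `predicate` are illustrative names, not Spark APIs:

```scala
// Conceptual sketch only, not the actual codegen: `loadValue` stands in for the
// possibly expensive column-batch access, e.g. decoding a dictionary-encoded value.
def passesFilter(isNull: Boolean, loadValue: () => Int, predicate: Int => Boolean): Boolean =
  if (isNull) {
    false                      // short-circuit: never touch the value for null rows
  } else {
    val value = loadValue()    // deferred, evaluated only for non-null rows
    predicate(value)
  }
```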

## How was this patch tested?

On tpcds q55, this fixes the regression from introducing the IsNotNull predicates.

```
TPCDS Snappy:                       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)
--------------------------------------------------------------------------------
q55                                      4564 / 5036         25.2          39.6
q55                                      4064 / 4340         28.3          35.3
```

Author: Nong Li 

Closes #11792 from nongli/spark-13981.
---
 .../spark/sql/execution/basicOperators.scala       | 77 +++++++++++++++++-----
 1 file changed, 61 insertions(+), 16 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 70e04d022f..fca662760d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics}
+import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.LongType
 import org.apache.spark.util.random.PoissonSampler
 
@@ -79,16 +79,20 @@ case class Filter(condition: Expression, child: SparkPlan)
 
   // Split out all the IsNotNulls from condition.
   private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition {
-    case IsNotNull(a) if child.output.contains(a) => true
+    case IsNotNull(a) if child.output.exists(_.semanticEquals(a)) => true
     case _ => false
   }
 
   // The columns that will filtered out by `IsNotNull` could be considered as not nullable.
   private val notNullAttributes = notNullPreds.flatMap(_.references)
 
+  // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate
+  // all the variables at the beginning to take advantage of short circuiting.
+  override def usedInputs: AttributeSet = AttributeSet.empty
+
   override def output: Seq[Attribute] = {
     child.output.map { a =>
-      if (a.nullable && notNullAttributes.contains(a)) {
+      if (a.nullable && notNullAttributes.exists(_.semanticEquals(a))) {
         a.withNullability(false)
       } else {
         a
@@ -110,39 +114,80 @@ case class Filter(condition: Expression, child: SparkPlan)
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     val numOutput = metricTerm(ctx, "numOutputRows")
 
-    // filter out the nulls
-    val filterOutNull = notNullAttributes.map { a =>
-      val idx = child.output.indexOf(a)
-      s"if (${input(idx).isNull}) continue;"
-    }.mkString("\n")
+    /**
+     * Generates code for `c`, using `in` for input attributes and `attrs` for nullability.
+     */
+    def genPredicate(c: Expression, in: Seq[ExprCode], attrs: Seq[Attribute]): String = {
+      val bound = BindReferences.bindReference(c, attrs)
+      val evaluated = evaluateRequiredVariables(child.output, in, c.references)
 
-    ctx.currentVars = input
-    val predicates = otherPreds.map { e =>
-      val bound = ExpressionCanonicalizer.execute(
-        BindReferences.bindReference(e, output))
-      val ev = bound.gen(ctx)
+      // Generate the code for the predicate.
+      val ev = ExpressionCanonicalizer.execute(bound).gen(ctx)
       val nullCheck = if (bound.nullable) {
         s"${ev.isNull} || "
       } else {
         s""
       }
+
       s"""
+         |$evaluated
          |${ev.code}
          |if (${nullCheck}!${ev.value}) continue;
        """.stripMargin
+    }
+
+    ctx.currentVars = input
+
+    // To generate the predicates we will follow this algorithm.
+    // For each predicate that is not IsNotNull, we will generate them one by one, loading attributes
+    // as necessary. For each of those attributes, if there is an IsNotNull predicate we will generate
+    // that check *before* the predicate. After all of these predicates, we will generate the
+    // remaining IsNotNull checks that were not part of other predicates.
+    // This has the property of not doing redundant IsNotNull checks and taking better advantage of
+    // short-circuiting, not loading attributes until they are needed.
+    // This is very perf sensitive.
+    // TODO: revisit this. We can consider reordering predicates as well.
+    val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length)
+    val generated = otherPreds.map { c =>
+      val nullChecks = c.references.map { r =>
+        val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)}
+        if (idx != -1 && !generatedIsNotNullChecks(idx)) {
+          generatedIsNotNullChecks(idx) = true
+          // Use the child's output. The nullability is what the child produced.
+          genPredicate(notNullPreds(idx), input, child.output)
+        } else {
+          ""
+        }
+      }.mkString("\n").trim
+
+      // Here we use *this* operator's output with this output's nullability since we already
+      // enforced them with the IsNotNull checks above.
+      s"""
+         |$nullChecks
+         |${genPredicate(c, input, output)}
+       """.stripMargin.trim
+    }.mkString("\n")
+
+    val nullChecks = notNullPreds.zipWithIndex.map { case (c, idx) =>
+      if (!generatedIsNotNullChecks(idx)) {
+        genPredicate(c, input, child.output)
+      } else {
+        ""
+      }
     }.mkString("\n")
 
     // Reset the isNull to false for the not-null columns, then the followed operators could
     // generate better code (remove dead branches).
     val resultVars = input.zipWithIndex.map { case (ev, i) =>
-      if (notNullAttributes.contains(child.output(i))) {
+      if (notNullAttributes.exists(_.semanticEquals(child.output(i)))) {
         ev.isNull = "false"
       }
       ev
     }
+
     s"""
-       |$filterOutNull
-       |$predicates
+       |$generated
+       |$nullChecks
        |$numOutput.add(1);
        |${consume(ctx, resultVars)}
      """.stripMargin
-- 
cgit v1.2.3


From a180286b7994f9f9a56b84903cc9ee6057ba6624 Mon Sep 17 00:00:00 2001
From: Nong Li 
Date: Mon, 28 Mar 2016 21:37:46 -0700
Subject: [SPARK-14210] [SQL] Add a metric for time spent in scans.

## What changes were proposed in this pull request?

This adds a metric to Parquet scans that measures the time spent in just the scan phase. This is
only possible when the scan returns ColumnarBatches; otherwise the overhead is too high.

Combined with the pipeline metric, this lets us easily see what percentage of the time was spent
in the scan.
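
The timing pattern the generated code follows can be sketched in plain Scala (a hypothetical
`timedScan` helper, not the actual DataSourceScan codegen): nanoseconds are accumulated only
around the batch-fetch boundary and converted to milliseconds once the scan finishes.

```
// Sketch only: time the fetching of batches, excluding the per-row processing that follows.
def timedScan[T](batches: Iterator[T])(process: T => Unit): Long = {
  var scanTimeTotalNs = 0L
  var done = false
  while (!done) {
    val start = System.nanoTime()
    val batch = if (batches.hasNext) Some(batches.next()) else None
    scanTimeTotalNs += System.nanoTime() - start   // only the time to obtain a batch is counted
    batch match {
      case Some(b) => process(b)                   // downstream processing is not counted
      case None    => done = true
    }
  }
  scanTimeTotalNs / (1000 * 1000)                  // the metric is reported in milliseconds
}
```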

Author: Nong Li 

Closes #12007 from nongli/spark-14210.
---
 .../apache/spark/sql/execution/ExistingRDD.scala   | 157 ++++++++++++---------
 1 file changed, 94 insertions(+), 63 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 815ff01c4c..ab575e90c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
-import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.util.toCommentSafeString
 import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -139,8 +139,12 @@ private[sql] case class DataSourceScan(
     case _ => false
   }
 
-  private[sql] override lazy val metrics = Map(
-    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+  private[sql] override lazy val metrics = if (canProcessBatches()) {
+    Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
+      "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
+  } else {
+    Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+  }
 
   val outputUnsafeRows = relation match {
     case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] =>
@@ -170,6 +174,17 @@ private[sql] case class DataSourceScan(
     }
   }
 
+  private def canProcessBatches(): Boolean = {
+    relation match {
+      case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] &&
+        SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) &&
+        SQLContext.getActive().get.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED) =>
+        true
+      case _ =>
+        false
+    }
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val unsafeRow = if (outputUnsafeRows) {
       rdd
@@ -241,73 +256,89 @@ private[sql] case class DataSourceScan(
     // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know
     // here which path to use. Fix this.
 
-    ctx.currentVars = null
-    val columns1 = (output zip colVars).map { case (attr, colVar) =>
-      genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) }
-    val scanBatches = ctx.freshName("processBatches")
-    ctx.addNewFunction(scanBatches,
-      s"""
-      | private void $scanBatches() throws java.io.IOException {
-      |  while (true) {
-      |     int numRows = $batch.numRows();
-      |     if ($idx == 0) {
-      |       ${columnAssigns.mkString("", "\n", "\n")}
-      |       $numOutputRows.add(numRows);
-      |     }
-      |
-      |     // this loop is very perf sensitive and changes to it should be measured carefully
-      |     while ($idx < numRows) {
-      |       int $rowidx = $idx++;
-      |       ${consume(ctx, columns1).trim}
-      |       if (shouldStop()) return;
-      |     }
-      |
-      |     if (!$input.hasNext()) {
-      |       $batch = null;
-      |       break;
-      |     }
-      |     $batch = ($columnarBatchClz)$input.next();
-      |     $idx = 0;
-      |   }
-      | }""".stripMargin)
-
     val exprRows =
-      output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable))
+        output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable))
     ctx.INPUT_ROW = row
     ctx.currentVars = null
-    val columns2 = exprRows.map(_.gen(ctx))
+    val columnsRowInput = exprRows.map(_.gen(ctx))
     val inputRow = if (outputUnsafeRows) row else null
     val scanRows = ctx.freshName("processRows")
     ctx.addNewFunction(scanRows,
       s"""
-       | private void $scanRows(InternalRow $row) throws java.io.IOException {
-       |   boolean firstRow = true;
-       |   while (firstRow || $input.hasNext()) {
-       |     if (firstRow) {
-       |       firstRow = false;
-       |     } else {
-       |       $row = (InternalRow) $input.next();
-       |     }
-       |     $numOutputRows.add(1);
-       |     ${consume(ctx, columns2, inputRow).trim}
-       |     if (shouldStop()) return;
-       |   }
-       | }""".stripMargin)
-
-    val value = ctx.freshName("value")
-    s"""
-       | if ($batch != null) {
-       |   $scanBatches();
-       | } else if ($input.hasNext()) {
-       |   Object $value = $input.next();
-       |   if ($value instanceof $columnarBatchClz) {
-       |     $batch = ($columnarBatchClz)$value;
-       |     $scanBatches();
-       |   } else {
-       |     $scanRows((InternalRow) $value);
-       |   }
-       | }
-     """.stripMargin
+         | private void $scanRows(InternalRow $row) throws java.io.IOException {
+         |   boolean firstRow = true;
+         |   while (!shouldStop() && (firstRow || $input.hasNext())) {
+         |     if (firstRow) {
+         |       firstRow = false;
+         |     } else {
+         |       $row = (InternalRow) $input.next();
+         |     }
+         |     $numOutputRows.add(1);
+         |     ${consume(ctx, columnsRowInput, inputRow).trim}
+         |   }
+         | }""".stripMargin)
+
+    // Timers for how long we spent inside the scan. We can only maintain this when using batches,
+    // otherwise the overhead is too high.
+    if (canProcessBatches()) {
+      val scanTimeMetric = metricTerm(ctx, "scanTime")
+      val getBatchStart = ctx.freshName("scanStart")
+      val scanTimeTotalNs = ctx.freshName("scanTime")
+      ctx.currentVars = null
+      val columnsBatchInput = (output zip colVars).map { case (attr, colVar) =>
+        genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) }
+      val scanBatches = ctx.freshName("processBatches")
+      ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;")
+
+      ctx.addNewFunction(scanBatches,
+        s"""
+        | private void $scanBatches() throws java.io.IOException {
+        |  while (true) {
+        |     int numRows = $batch.numRows();
+        |     if ($idx == 0) {
+        |       ${columnAssigns.mkString("", "\n", "\n")}
+        |       $numOutputRows.add(numRows);
+        |     }
+        |
+        |     while (!shouldStop() && $idx < numRows) {
+        |       int $rowidx = $idx++;
+        |       ${consume(ctx, columnsBatchInput).trim}
+        |     }
+        |     if (shouldStop()) return;
+        |
+        |     long $getBatchStart = System.nanoTime();
+        |     if (!$input.hasNext()) {
+        |       $batch = null;
+        |       $scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000));
+        |       break;
+        |     }
+        |     $batch = ($columnarBatchClz)$input.next();
+        |     $scanTimeTotalNs += System.nanoTime() - $getBatchStart;
+        |     $idx = 0;
+        |   }
+        | }""".stripMargin)
+
+      val value = ctx.freshName("value")
+      s"""
+         | if ($batch != null) {
+         |   $scanBatches();
+         | } else if ($input.hasNext()) {
+         |   Object $value = $input.next();
+         |   if ($value instanceof $columnarBatchClz) {
+         |     $batch = ($columnarBatchClz)$value;
+         |     $scanBatches();
+         |   } else {
+         |     $scanRows((InternalRow) $value);
+         |   }
+         | }
+       """.stripMargin
+    } else {
+      s"""
+         |if ($input.hasNext()) {
+         |  $scanRows((InternalRow) $input.next());
+         |}
+       """.stripMargin
+    }
   }
 }
 
-- 
cgit v1.2.3


From d3638d7bffd4ee43db594c0669d86fb64d448fc8 Mon Sep 17 00:00:00 2001
From: Sun Rui 
Date: Mon, 28 Mar 2016 21:51:02 -0700
Subject: [SPARK-12792] [SPARKR] Refactor RRDD to support R UDF.

## What changes were proposed in this pull request?

Refactor RRDD by separating the common logic for interacting with the R worker into a new class, RRunner, which can be used to evaluate R UDFs.

RRDD now relies on RRunner for RDD computation, and RRDD could be removed if we want to remove the RDD API in SparkR later.
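
The shape of the refactoring can be summarized with a condensed Scala sketch (simplified
signatures and hypothetical names, not the real SparkR classes): the runner owns the worker
protocol, while the RDD stays a thin adapter around it.

```
// Sketch only: the runner encapsulates per-partition evaluation; the RDD just delegates to it.
class RRunnerSketch[T, U](eval: Iterator[T] => Iterator[U]) {
  def compute(input: Iterator[T], partitionIndex: Int): Iterator[U] = eval(input)
}

class RRDDSketch[T, U](runner: RRunnerSketch[T, U]) {
  def compute(parentIterator: Iterator[T], partitionIndex: Int): Iterator[U] =
    runner.compute(parentIterator, partitionIndex)   // the entry point an R UDF evaluator can reuse
}
```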

## How was this patch tested?
dev/lint-r
SparkR unit tests

Author: Sun Rui 

Closes #12024 from sun-rui/SPARK-12792_new.
---
 R/pkg/inst/tests/testthat/test_rdd.R               |   8 +
 .../main/scala/org/apache/spark/api/r/RRDD.scala   | 328 +-----------------
 .../scala/org/apache/spark/api/r/RRunner.scala     | 368 +++++++++++++++++++++
 3 files changed, 380 insertions(+), 324 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/api/r/RRunner.scala

diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R
index 3b0c16be5a..b6c8e1dc6c 100644
--- a/R/pkg/inst/tests/testthat/test_rdd.R
+++ b/R/pkg/inst/tests/testthat/test_rdd.R
@@ -791,3 +791,11 @@ test_that("sampleByKey() on pairwise RDDs", {
   expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE)
   expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE)
 })
+
+test_that("Test correct concurrency of RRDD.compute()", {
+  rdd <- parallelize(sc, 1:1000, 100)
+  jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row")
+  zrdd <- callJMethod(jrdd, "zip", jrdd)
+  count <- callJMethod(zrdd, "count")
+  expect_equal(count, 1000)
+})
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 588a57e65f..606ba6ef86 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -17,21 +17,16 @@
 
 package org.apache.spark.api.r
 
-import java.io._
-import java.net.{InetAddress, ServerSocket}
-import java.util.{Arrays, Map => JMap}
+import java.util.{Map => JMap}
 
 import scala.collection.JavaConverters._
-import scala.io.Source
 import scala.reflect.ClassTag
-import scala.util.Try
 
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
 
 private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     parent: RDD[T],
@@ -42,188 +37,16 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     packageNames: Array[Byte],
     broadcastVars: Array[Broadcast[Object]])
   extends RDD[U](parent) with Logging {
-  protected var dataStream: DataInputStream = _
-  private var bootTime: Double = _
   override def getPartitions: Array[Partition] = parent.partitions
 
   override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
-
-    // Timing start
-    bootTime = System.currentTimeMillis / 1000.0
+    val runner = new RRunner[U](
+      func, deserializer, serializer, packageNames, broadcastVars, numPartitions)
 
     // The parent may be also an RRDD, so we should launch it first.
     val parentIterator = firstParent[T].iterator(partition, context)
 
-    // we expect two connections
-    val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
-    val listenPort = serverSocket.getLocalPort()
-
-    // The stdout/stderr is shared by multiple tasks, because we use one daemon
-    // to launch child process as worker.
-    val errThread = RRDD.createRWorker(listenPort)
-
-    // We use two sockets to separate input and output, then it's easy to manage
-    // the lifecycle of them to avoid deadlock.
-    // TODO: optimize it to use one socket
-
-    // the socket used to send out the input of task
-    serverSocket.setSoTimeout(10000)
-    val inSocket = serverSocket.accept()
-    startStdinThread(inSocket.getOutputStream(), parentIterator, partition.index)
-
-    // the socket used to receive the output of task
-    val outSocket = serverSocket.accept()
-    val inputStream = new BufferedInputStream(outSocket.getInputStream)
-    dataStream = new DataInputStream(inputStream)
-    serverSocket.close()
-
-    try {
-
-      return new Iterator[U] {
-        def next(): U = {
-          val obj = _nextObj
-          if (hasNext) {
-            _nextObj = read()
-          }
-          obj
-        }
-
-        var _nextObj = read()
-
-        def hasNext(): Boolean = {
-          val hasMore = (_nextObj != null)
-          if (!hasMore) {
-            dataStream.close()
-          }
-          hasMore
-        }
-      }
-    } catch {
-      case e: Exception =>
-        throw new SparkException("R computation failed with\n " + errThread.getLines())
-    }
-  }
-
-  /**
-   * Start a thread to write RDD data to the R process.
-   */
-  private def startStdinThread[T](
-    output: OutputStream,
-    iter: Iterator[T],
-    partition: Int): Unit = {
-
-    val env = SparkEnv.get
-    val taskContext = TaskContext.get()
-    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
-    val stream = new BufferedOutputStream(output, bufferSize)
-
-    new Thread("writer for R") {
-      override def run(): Unit = {
-        try {
-          SparkEnv.set(env)
-          TaskContext.setTaskContext(taskContext)
-          val dataOut = new DataOutputStream(stream)
-          dataOut.writeInt(partition)
-
-          SerDe.writeString(dataOut, deserializer)
-          SerDe.writeString(dataOut, serializer)
-
-          dataOut.writeInt(packageNames.length)
-          dataOut.write(packageNames)
-
-          dataOut.writeInt(func.length)
-          dataOut.write(func)
-
-          dataOut.writeInt(broadcastVars.length)
-          broadcastVars.foreach { broadcast =>
-            // TODO(shivaram): Read a Long in R to avoid this cast
-            dataOut.writeInt(broadcast.id.toInt)
-            // TODO: Pass a byte array from R to avoid this cast ?
-            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
-            dataOut.writeInt(broadcastByteArr.length)
-            dataOut.write(broadcastByteArr)
-          }
-
-          dataOut.writeInt(numPartitions)
-
-          if (!iter.hasNext) {
-            dataOut.writeInt(0)
-          } else {
-            dataOut.writeInt(1)
-          }
-
-          val printOut = new PrintStream(stream)
-
-          def writeElem(elem: Any): Unit = {
-            if (deserializer == SerializationFormats.BYTE) {
-              val elemArr = elem.asInstanceOf[Array[Byte]]
-              dataOut.writeInt(elemArr.length)
-              dataOut.write(elemArr)
-            } else if (deserializer == SerializationFormats.ROW) {
-              dataOut.write(elem.asInstanceOf[Array[Byte]])
-            } else if (deserializer == SerializationFormats.STRING) {
-              // write string(for StringRRDD)
-              // scalastyle:off println
-              printOut.println(elem)
-              // scalastyle:on println
-            }
-          }
-
-          for (elem <- iter) {
-            elem match {
-              case (key, value) =>
-                writeElem(key)
-                writeElem(value)
-              case _ =>
-                writeElem(elem)
-            }
-          }
-          stream.flush()
-        } catch {
-          // TODO: We should propogate this error to the task thread
-          case e: Exception =>
-            logError("R Writer thread got an exception", e)
-        } finally {
-          Try(output.close())
-        }
-      }
-    }.start()
-  }
-
-  protected def readData(length: Int): U
-
-  protected def read(): U = {
-    try {
-      val length = dataStream.readInt()
-
-      length match {
-        case SpecialLengths.TIMING_DATA =>
-          // Timing data from R worker
-          val boot = dataStream.readDouble - bootTime
-          val init = dataStream.readDouble
-          val broadcast = dataStream.readDouble
-          val input = dataStream.readDouble
-          val compute = dataStream.readDouble
-          val output = dataStream.readDouble
-          logInfo(
-            ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
-             "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
-             "total = %.3f s").format(
-               boot,
-               init,
-               broadcast,
-               input,
-               compute,
-               output,
-               boot + init + broadcast + input + compute + output))
-          read()
-        case length if length >= 0 =>
-          readData(length)
-      }
-    } catch {
-      case eof: EOFException =>
-        throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
-    }
+    runner.compute(parentIterator, partition.index, context)
   }
 }
 
@@ -242,19 +65,6 @@ private class PairwiseRRDD[T: ClassTag](
     parent, numPartitions, hashFunc, deserializer,
     SerializationFormats.BYTE, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
-  override protected def readData(length: Int): (Int, Array[Byte]) = {
-    length match {
-      case length if length == 2 =>
-        val hashedKey = dataStream.readInt()
-        val contentPairsLength = dataStream.readInt()
-        val contentPairs = new Array[Byte](contentPairsLength)
-        dataStream.readFully(contentPairs)
-        (hashedKey, contentPairs)
-      case _ => null
-   }
-  }
-
   lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
 
@@ -271,17 +81,6 @@ private class RRDD[T: ClassTag](
   extends BaseRRDD[T, Array[Byte]](
     parent, -1, func, deserializer, serializer, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
-  override protected def readData(length: Int): Array[Byte] = {
-    length match {
-      case length if length > 0 =>
-        val obj = new Array[Byte](length)
-        dataStream.readFully(obj)
-        obj
-      case _ => null
-    }
-  }
-
   lazy val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
 }
 
@@ -297,55 +96,10 @@ private class StringRRDD[T: ClassTag](
   extends BaseRRDD[T, String](
     parent, -1, func, deserializer, SerializationFormats.STRING, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
-
-  override protected def readData(length: Int): String = {
-    length match {
-      case length if length > 0 =>
-        SerDe.readStringBytes(dataStream, length)
-      case _ => null
-    }
-  }
-
   lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
 }
 
-private object SpecialLengths {
-  val TIMING_DATA = -1
-}
-
-private[r] class BufferedStreamThread(
-    in: InputStream,
-    name: String,
-    errBufferSize: Int) extends Thread(name) with Logging {
-  val lines = new Array[String](errBufferSize)
-  var lineIdx = 0
-  override def run() {
-    for (line <- Source.fromInputStream(in).getLines) {
-      synchronized {
-        lines(lineIdx) = line
-        lineIdx = (lineIdx + 1) % errBufferSize
-      }
-      logInfo(line)
-    }
-  }
-
-  def getLines(): String = synchronized {
-    (0 until errBufferSize).filter { x =>
-      lines((x + lineIdx) % errBufferSize) != null
-    }.map { x =>
-      lines((x + lineIdx) % errBufferSize)
-    }.mkString("\n")
-  }
-}
-
 private[r] object RRDD {
-  // Because forking processes from Java is expensive, we prefer to launch
-  // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
-  // This daemon currently only works on UNIX-based systems now, so we should
-  // also fall back to launching workers (worker.R) directly.
-  private[this] var errThread: BufferedStreamThread = _
-  private[this] var daemonChannel: DataOutputStream = _
-
   def createSparkContext(
       master: String,
       appName: String,
@@ -353,7 +107,6 @@ private[r] object RRDD {
       jars: Array[String],
       sparkEnvirMap: JMap[Object, Object],
       sparkExecutorEnvMap: JMap[Object, Object]): JavaSparkContext = {
-
     val sparkConf = new SparkConf().setAppName(appName)
                                    .setSparkHome(sparkHome)
 
@@ -380,78 +133,6 @@ private[r] object RRDD {
     jsc
   }
 
-  /**
-   * Start a thread to print the process's stderr to ours
-   */
-  private def startStdoutThread(proc: Process): BufferedStreamThread = {
-    val BUFFER_SIZE = 100
-    val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
-    thread.setDaemon(true)
-    thread.start()
-    thread
-  }
-
-  private def createRProcess(port: Int, script: String): BufferedStreamThread = {
-    // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
-    // but kept here for backward compatibility.
-    val sparkConf = SparkEnv.get.conf
-    var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
-    rCommand = sparkConf.get("spark.r.command", rCommand)
-
-    val rOptions = "--vanilla"
-    val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
-    val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
-    val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
-    // Unset the R_TESTS environment variable for workers.
-    // This is set by R CMD check as startup.Rs
-    // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
-    // and confuses worker script which tries to load a non-existent file
-    pb.environment().put("R_TESTS", "")
-    pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
-    pb.environment().put("SPARKR_WORKER_PORT", port.toString)
-    pb.redirectErrorStream(true)  // redirect stderr into stdout
-    val proc = pb.start()
-    val errThread = startStdoutThread(proc)
-    errThread
-  }
-
-  /**
-   * ProcessBuilder used to launch worker R processes.
-   */
-  def createRWorker(port: Int): BufferedStreamThread = {
-    val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
-    if (!Utils.isWindows && useDaemon) {
-      synchronized {
-        if (daemonChannel == null) {
-          // we expect one connections
-          val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
-          val daemonPort = serverSocket.getLocalPort
-          errThread = createRProcess(daemonPort, "daemon.R")
-          // the socket used to send out the input of task
-          serverSocket.setSoTimeout(10000)
-          val sock = serverSocket.accept()
-          daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
-          serverSocket.close()
-        }
-        try {
-          daemonChannel.writeInt(port)
-          daemonChannel.flush()
-        } catch {
-          case e: IOException =>
-            // daemon process died
-            daemonChannel.close()
-            daemonChannel = null
-            errThread = null
-            // fail the current task, retry by scheduler
-            throw e
-        }
-        errThread
-      }
-    } else {
-      createRProcess(port, "worker.R")
-    }
-  }
-
   /**
    * Create an RRDD given a sequence of byte arrays. Used to create RRDD when `parallelize` is
    * called from R.
@@ -459,5 +140,4 @@ private[r] object RRDD {
   def createRDDFromArray(jsc: JavaSparkContext, arr: Array[Array[Byte]]): JavaRDD[Array[Byte]] = {
     JavaRDD.fromRDD(jsc.sc.parallelize(arr, arr.length))
   }
-
 }
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
new file mode 100644
index 0000000000..ff279ec270
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -0,0 +1,368 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io._
+import java.net.{InetAddress, ServerSocket}
+import java.util.Arrays
+
+import scala.io.Source
+import scala.util.Try
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * A helper class to run R UDFs in Spark.
+ */
+private[spark] class RRunner[U](
+    func: Array[Byte],
+    deserializer: String,
+    serializer: String,
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    numPartitions: Int = -1)
+  extends Logging {
+  private var bootTime: Double = _
+  private var dataStream: DataInputStream = _
+  val readData = numPartitions match {
+    case -1 =>
+      serializer match {
+        case SerializationFormats.STRING => readStringData _
+        case _ => readByteArrayData _
+      }
+    case _ => readShuffledData _
+  }
+
+  def compute(
+      inputIterator: Iterator[_],
+      partitionIndex: Int,
+      context: TaskContext): Iterator[U] = {
+    // Timing start
+    bootTime = System.currentTimeMillis / 1000.0
+
+    // we expect two connections
+    val serverSocket = new ServerSocket(0, 2, InetAddress.getByName("localhost"))
+    val listenPort = serverSocket.getLocalPort()
+
+    // The stdout/stderr is shared by multiple tasks, because we use one daemon
+    // to launch child process as worker.
+    val errThread = RRunner.createRWorker(listenPort)
+
+    // We use two sockets to separate input and output, then it's easy to manage
+    // the lifecycle of them to avoid deadlock.
+    // TODO: optimize it to use one socket
+
+    // the socket used to send out the input of task
+    serverSocket.setSoTimeout(10000)
+    val inSocket = serverSocket.accept()
+    startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
+
+    // the socket used to receive the output of task
+    val outSocket = serverSocket.accept()
+    val inputStream = new BufferedInputStream(outSocket.getInputStream)
+    dataStream = new DataInputStream(inputStream)
+    serverSocket.close()
+
+    try {
+      return new Iterator[U] {
+        def next(): U = {
+          val obj = _nextObj
+          if (hasNext) {
+            _nextObj = read()
+          }
+          obj
+        }
+
+        var _nextObj = read()
+
+        def hasNext(): Boolean = {
+          val hasMore = (_nextObj != null)
+          if (!hasMore) {
+            dataStream.close()
+          }
+          hasMore
+        }
+      }
+    } catch {
+      case e: Exception =>
+        throw new SparkException("R computation failed with\n " + errThread.getLines())
+    }
+  }
+
+  /**
+   * Start a thread to write RDD data to the R process.
+   */
+  private def startStdinThread(
+      output: OutputStream,
+      iter: Iterator[_],
+      partitionIndex: Int): Unit = {
+    val env = SparkEnv.get
+    val taskContext = TaskContext.get()
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+    val stream = new BufferedOutputStream(output, bufferSize)
+
+    new Thread("writer for R") {
+      override def run(): Unit = {
+        try {
+          SparkEnv.set(env)
+          TaskContext.setTaskContext(taskContext)
+          val dataOut = new DataOutputStream(stream)
+          dataOut.writeInt(partitionIndex)
+
+          SerDe.writeString(dataOut, deserializer)
+          SerDe.writeString(dataOut, serializer)
+
+          dataOut.writeInt(packageNames.length)
+          dataOut.write(packageNames)
+
+          dataOut.writeInt(func.length)
+          dataOut.write(func)
+
+          dataOut.writeInt(broadcastVars.length)
+          broadcastVars.foreach { broadcast =>
+            // TODO(shivaram): Read a Long in R to avoid this cast
+            dataOut.writeInt(broadcast.id.toInt)
+            // TODO: Pass a byte array from R to avoid this cast ?
+            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+            dataOut.writeInt(broadcastByteArr.length)
+            dataOut.write(broadcastByteArr)
+          }
+
+          dataOut.writeInt(numPartitions)
+
+          if (!iter.hasNext) {
+            dataOut.writeInt(0)
+          } else {
+            dataOut.writeInt(1)
+          }
+
+          val printOut = new PrintStream(stream)
+
+          def writeElem(elem: Any): Unit = {
+            if (deserializer == SerializationFormats.BYTE) {
+              val elemArr = elem.asInstanceOf[Array[Byte]]
+              dataOut.writeInt(elemArr.length)
+              dataOut.write(elemArr)
+            } else if (deserializer == SerializationFormats.ROW) {
+              dataOut.write(elem.asInstanceOf[Array[Byte]])
+            } else if (deserializer == SerializationFormats.STRING) {
+              // write string(for StringRRDD)
+              // scalastyle:off println
+              printOut.println(elem)
+              // scalastyle:on println
+            }
+          }
+
+          for (elem <- iter) {
+            elem match {
+              case (key, value) =>
+                writeElem(key)
+                writeElem(value)
+              case _ =>
+                writeElem(elem)
+            }
+          }
+          stream.flush()
+        } catch {
+          // TODO: We should propagate this error to the task thread
+          case e: Exception =>
+            logError("R Writer thread got an exception", e)
+        } finally {
+          Try(output.close())
+        }
+      }
+    }.start()
+  }
+
+  private def read(): U = {
+    try {
+      val length = dataStream.readInt()
+
+      length match {
+        case SpecialLengths.TIMING_DATA =>
+          // Timing data from R worker
+          val boot = dataStream.readDouble - bootTime
+          val init = dataStream.readDouble
+          val broadcast = dataStream.readDouble
+          val input = dataStream.readDouble
+          val compute = dataStream.readDouble
+          val output = dataStream.readDouble
+          logInfo(
+            ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+              "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
+              "total = %.3f s").format(
+                boot,
+                init,
+                broadcast,
+                input,
+                compute,
+                output,
+                boot + init + broadcast + input + compute + output))
+          read()
+        case length if length >= 0 =>
+          readData(length).asInstanceOf[U]
+      }
+    } catch {
+      case eof: EOFException =>
+        throw new SparkException("R worker exited unexpectedly (cranshed)", eof)
+    }
+  }
+
+  private def readShuffledData(length: Int): (Int, Array[Byte]) = {
+    length match {
+      case length if length == 2 =>
+        val hashedKey = dataStream.readInt()
+        val contentPairsLength = dataStream.readInt()
+        val contentPairs = new Array[Byte](contentPairsLength)
+        dataStream.readFully(contentPairs)
+        (hashedKey, contentPairs)
+      case _ => null
+    }
+  }
+
+  private def readByteArrayData(length: Int): Array[Byte] = {
+    length match {
+      case length if length > 0 =>
+        val obj = new Array[Byte](length)
+        dataStream.readFully(obj)
+        obj
+      case _ => null
+    }
+  }
+
+  private def readStringData(length: Int): String = {
+    length match {
+      case length if length > 0 =>
+        SerDe.readStringBytes(dataStream, length)
+      case _ => null
+    }
+  }
+}
+
+private object SpecialLengths {
+  val TIMING_DATA = -1
+}
+
+private[r] class BufferedStreamThread(
+    in: InputStream,
+    name: String,
+    errBufferSize: Int) extends Thread(name) with Logging {
+  val lines = new Array[String](errBufferSize)
+  var lineIdx = 0
+  override def run() {
+    for (line <- Source.fromInputStream(in).getLines) {
+      synchronized {
+        lines(lineIdx) = line
+        lineIdx = (lineIdx + 1) % errBufferSize
+      }
+      logInfo(line)
+    }
+  }
+
+  def getLines(): String = synchronized {
+    (0 until errBufferSize).filter { x =>
+      lines((x + lineIdx) % errBufferSize) != null
+    }.map { x =>
+      lines((x + lineIdx) % errBufferSize)
+    }.mkString("\n")
+  }
+}
+
+private[r] object RRunner {
+  // Because forking processes from Java is expensive, we prefer to launch
+  // a single R daemon (daemon.R) and tell it to fork new workers for our tasks.
+  // This daemon currently only works on UNIX-based systems, so we should
+  // also fall back to launching workers (worker.R) directly.
+  private[this] var errThread: BufferedStreamThread = _
+  private[this] var daemonChannel: DataOutputStream = _
+
+  /**
+   * Start a thread to print the process's stderr to ours
+   */
+  private def startStdoutThread(proc: Process): BufferedStreamThread = {
+    val BUFFER_SIZE = 100
+    val thread = new BufferedStreamThread(proc.getInputStream, "stdout reader for R", BUFFER_SIZE)
+    thread.setDaemon(true)
+    thread.start()
+    thread
+  }
+
+  private def createRProcess(port: Int, script: String): BufferedStreamThread = {
+    // "spark.sparkr.r.command" is deprecated and replaced by "spark.r.command",
+    // but kept here for backward compatibility.
+    val sparkConf = SparkEnv.get.conf
+    var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
+    rCommand = sparkConf.get("spark.r.command", rCommand)
+
+    val rOptions = "--vanilla"
+    val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
+    val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
+    val pb = new ProcessBuilder(Arrays.asList(rCommand, rOptions, rExecScript))
+    // Unset the R_TESTS environment variable for workers.
+    // This is set by R CMD check as startup.Rs
+    // (http://svn.r-project.org/R/trunk/src/library/tools/R/testing.R)
+    // and confuses worker script which tries to load a non-existent file
+    pb.environment().put("R_TESTS", "")
+    pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
+    pb.environment().put("SPARKR_WORKER_PORT", port.toString)
+    pb.redirectErrorStream(true)  // redirect stderr into stdout
+    val proc = pb.start()
+    val errThread = startStdoutThread(proc)
+    errThread
+  }
+
+  /**
+   * ProcessBuilder used to launch worker R processes.
+   */
+  def createRWorker(port: Int): BufferedStreamThread = {
+    val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
+    if (!Utils.isWindows && useDaemon) {
+      synchronized {
+        if (daemonChannel == null) {
+          // we expect one connection
+          val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
+          val daemonPort = serverSocket.getLocalPort
+          errThread = createRProcess(daemonPort, "daemon.R")
+          // the socket used to send out the input of task
+          serverSocket.setSoTimeout(10000)
+          val sock = serverSocket.accept()
+          daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          serverSocket.close()
+        }
+        try {
+          daemonChannel.writeInt(port)
+          daemonChannel.flush()
+        } catch {
+          case e: IOException =>
+            // daemon process died
+            daemonChannel.close()
+            daemonChannel = null
+            errThread = null
+            // fail the current task, retry by scheduler
+            throw e
+        }
+        errThread
+      }
+    } else {
+      createRProcess(port, "worker.R")
+    }
+  }
+}
-- 
cgit v1.2.3


From f6066b0c3c35ceea1706378145e15776c9b4415a Mon Sep 17 00:00:00 2001
From: sethah 
Date: Mon, 28 Mar 2016 22:27:53 -0700
Subject: [SPARK-11730][ML] Add feature importances for GBTs.

## What changes were proposed in this pull request?

Now that GBTs have been moved to ML, they can use the implementation of feature importance for random forests. This patch simply adds a `featureImportances` attribute to `GBTClassifier` and `GBTRegressor` and adds tests for each.

GBT feature importances here simply average the feature importances across the trees in the ensemble. This follows the implementation from scikit-learn. This method is also suggested by J. Friedman in [this paper](https://statweb.stanford.edu/~jhf/ftp/trebst.pdf).
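
As a rough illustration of that averaging (a sketch over plain Scala collections with
hypothetical inputs, not the `TreeEnsembleModel.featureImportances` implementation itself):
each tree's importances are normalized to sum to 1, accumulated across the ensemble, and the
combined vector is then normalized again.

```
// Sketch only: per-tree importances are maps from feature index (assumed < numFeatures) to gain.
def ensembleImportances(perTree: Seq[Map[Int, Double]], numFeatures: Int): Array[Double] = {
  val total = new Array[Double](numFeatures)
  perTree.foreach { importances =>
    val treeNorm = importances.values.sum
    if (treeNorm != 0.0) {
      // Normalize this tree's importances before adding them to the ensemble total.
      importances.foreach { case (featureIdx, gain) => total(featureIdx) += gain / treeNorm }
    }
  }
  val grandTotal = total.sum
  if (grandTotal != 0.0) total.map(_ / grandTotal) else total
}
```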

## How was this patch tested?

Unit tests were added to `GBTClassifierSuite` and `GBTRegressorSuite` to validate feature importances.

Author: sethah 

Closes #11961 from sethah/SPARK-11730.
---
 .../ml/classification/DecisionTreeClassifier.scala |   2 +-
 .../spark/ml/classification/GBTClassifier.scala    |  13 +++
 .../ml/classification/RandomForestClassifier.scala |  16 ++-
 .../ml/regression/DecisionTreeRegressor.scala      |   2 +-
 .../apache/spark/ml/regression/GBTRegressor.scala  |  13 +++
 .../ml/regression/RandomForestRegressor.scala      |  16 ++-
 .../spark/ml/tree/impl/GradientBoostedTrees.scala  |   2 +
 .../apache/spark/ml/tree/impl/RandomForest.scala   | 110 -------------------
 .../org/apache/spark/ml/tree/treeModels.scala      | 120 +++++++++++++++++++++
 .../ml/classification/GBTClassifierSuite.scala     |  25 +++++
 .../spark/ml/regression/GBTRegressorSuite.scala    |  23 ++++
 .../spark/ml/tree/impl/RandomForestSuite.scala     |   6 +-
 12 files changed, 213 insertions(+), 135 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 3e4b21bff6..23c4af17f9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -203,7 +203,7 @@ final class DecisionTreeClassificationModel private[ml] (
    *       to determine feature importance instead.
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
 
   /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
   override private[spark] def toOld: OldDecisionTreeModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c31df3aa18..48ce051d0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -238,6 +238,19 @@ final class GBTClassificationModel private[ml](
     s"GBTClassificationModel (uid=$uid) with $numTrees trees"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * Each feature's importance is the average of its importance across all trees in the ensemble
+   * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+   * and follows the implementation from scikit-learn.
+   *
+   * @see [[DecisionTreeClassificationModel.featureImportances]]
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 5da04d341d..82fa05a604 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -222,19 +222,15 @@ final class RandomForestClassificationModel private[ml] (
   /**
    * Estimate of the importance of each feature.
    *
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   * Each feature's importance is the average of its importance across all trees in the ensemble
+   * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+   * and follows the implementation from scikit-learn.
    *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
+   * @see [[DecisionTreeClassificationModel.featureImportances]]
    */
   @Since("1.5.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
 
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldRandomForestModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 50ac96eb5e..0a3d00e470 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -203,7 +203,7 @@ final class DecisionTreeRegressionModel private[ml] (
    *       to determine feature importance instead.
    */
   @Since("2.0.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(this, numFeatures)
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures)
 
   /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */
   override private[spark] def toOld: OldDecisionTreeModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index da5b77e8fa..8fca35da51 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -224,6 +224,19 @@ final class GBTRegressionModel private[ml](
     s"GBTRegressionModel (uid=$uid) with $numTrees trees"
   }
 
+  /**
+   * Estimate of the importance of each feature.
+   *
+   * Each feature's importance is the average of its importance across all trees in the ensemble
+   * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+   * and follows the implementation from scikit-learn.
+   *
+   * @see [[DecisionTreeRegressionModel.featureImportances]]
+   */
+  @Since("2.0.0")
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
+
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldGBTModel = {
     new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 798947b94a..5b3f3a1f5d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -181,19 +181,15 @@ final class RandomForestRegressionModel private[ml] (
   /**
    * Estimate of the importance of each feature.
    *
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   * Each feature's importance is the average of its importance across all trees in the ensemble
+   * The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.
+   * (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.)
+   * and follows the implementation from scikit-learn.
    *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
+   * @see [[DecisionTreeRegressionModel.featureImportances]]
    */
   @Since("1.5.0")
-  lazy val featureImportances: Vector = RandomForest.featureImportances(trees, numFeatures)
+  lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(trees, numFeatures)
 
   /** (private[ml]) Convert to a model in the old API */
   private[ml] def toOld: OldRandomForestModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
index 1c8a9b4dfe..b37f4e891e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala
@@ -19,7 +19,9 @@ package org.apache.spark.ml.tree.impl
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.tree.DecisionTreeModel
 import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 7774ae64e5..cccf052b3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -26,7 +26,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.classification.DecisionTreeClassificationModel
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
 import org.apache.spark.ml.tree._
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
@@ -35,7 +34,6 @@ import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.ImpurityStats
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
 
 
@@ -1105,112 +1103,4 @@ private[spark] object RandomForest extends Logging {
     }
   }
 
-  /**
-   * Given a Random Forest model, compute the importance of each feature.
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
-   *
-   * This feature importance is calculated as follows:
-   *  - Average over trees:
-   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-   *       where gain is scaled by the number of instances passing through node
-   *     - Normalize importances for tree to sum to 1.
-   *  - Normalize feature importance vector to sum to 1.
-   *
-   * @param trees  Unweighted forest of trees
-   * @param numFeatures  Number of features in model (even if not all are explicitly used by
-   *                     the model).
-   *                     If -1, then numFeatures is set based on the max feature index in all trees.
-   * @return  Feature importance values, of length numFeatures.
-   */
-  private[ml] def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
-    val totalImportances = new OpenHashMap[Int, Double]()
-    trees.foreach { tree =>
-      // Aggregate feature importance vector for this tree
-      val importances = new OpenHashMap[Int, Double]()
-      computeFeatureImportance(tree.rootNode, importances)
-      // Normalize importance vector for this tree, and add it to total.
-      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
-      val treeNorm = importances.map(_._2).sum
-      if (treeNorm != 0) {
-        importances.foreach { case (idx, impt) =>
-          val normImpt = impt / treeNorm
-          totalImportances.changeValue(idx, normImpt, _ + normImpt)
-        }
-      }
-    }
-    // Normalize importances
-    normalizeMapValues(totalImportances)
-    // Construct vector
-    val d = if (numFeatures != -1) {
-      numFeatures
-    } else {
-      // Find max feature index used in trees
-      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
-      maxFeatureIndex + 1
-    }
-    if (d == 0) {
-      assert(totalImportances.size == 0, s"Unknown error in computing feature" +
-        s" importance: No splits found, but some non-zero importances.")
-    }
-    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
-    Vectors.sparse(d, indices.toArray, values.toArray)
-  }
-
-  /**
-   * Given a Decision Tree model, compute the importance of each feature.
-   * This generalizes the idea of "Gini" importance to other losses,
-   * following the explanation of Gini importance from "Random Forests" documentation
-   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
-   *
-   * This feature importance is calculated as follows:
-   *  - importance(feature j) = sum (over nodes which split on feature j) of the gain,
-   *    where gain is scaled by the number of instances passing through node
-   *  - Normalize importances for tree to sum to 1.
-   *
-   * @param tree  Decision tree to compute importances for.
-   * @param numFeatures  Number of features in model (even if not all are explicitly used by
-   *                     the model).
-   *                     If -1, then numFeatures is set based on the max feature index in all trees.
-   * @return  Feature importance values, of length numFeatures.
-   */
-  private[ml] def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
-    featureImportances(Array(tree), numFeatures)
-  }
-
-  /**
-   * Recursive method for computing feature importances for one tree.
-   * This walks down the tree, adding to the importance of 1 feature at each node.
-   * @param node  Current node in recursion
-   * @param importances  Aggregate feature importances, modified by this method
-   */
-  private[impl] def computeFeatureImportance(
-      node: Node,
-      importances: OpenHashMap[Int, Double]): Unit = {
-    node match {
-      case n: InternalNode =>
-        val feature = n.split.featureIndex
-        val scaledGain = n.gain * n.impurityStats.count
-        importances.changeValue(feature, scaledGain, _ + scaledGain)
-        computeFeatureImportance(n.leftChild, importances)
-        computeFeatureImportance(n.rightChild, importances)
-      case n: LeafNode =>
-        // do nothing
-    }
-  }
-
-  /**
-   * Normalize the values of this map to sum to 1, in place.
-   * If all values are 0, this method does nothing.
-   * @param map  Map with non-negative values.
-   */
-  private[impl] def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
-    val total = map.map(_._2).sum
-    if (total != 0) {
-      val keys = map.iterator.map(_._1).toArray
-      keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
-    }
-  }
-
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index ef40c9068f..1fad9d6d8c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.sql.SQLContext
+import org.apache.spark.util.collection.OpenHashMap
 
 /**
  * Abstraction for Decision Tree models.
@@ -115,6 +116,125 @@ private[ml] trait TreeEnsembleModel {
   lazy val totalNumNodes: Int = trees.map(_.numNodes).sum
 }
 
+private[ml] object TreeEnsembleModel {
+
+  /**
+   * Given a tree ensemble model, compute the importance of each feature.
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   *
+   *  For collections of trees, including boosting and bagging, Hastie et al.
+   *  propose to use the average of single tree importances across all trees in the ensemble.
+   *
+   * This feature importance is calculated as follows:
+   *  - Average over trees:
+   *     - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+   *       where gain is scaled by the number of instances passing through node
+   *     - Normalize importances for tree to sum to 1.
+   *  - Normalize feature importance vector to sum to 1.
+   *
+   *  References:
+   *  - Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.
+   *
+   * @param trees  Unweighted collection of trees
+   * @param numFeatures  Number of features in model (even if not all are explicitly used by
+   *                     the model).
+   *                     If -1, then numFeatures is set based on the max feature index in all trees.
+   * @return  Feature importance values, of length numFeatures.
+   */
+  def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = {
+    val totalImportances = new OpenHashMap[Int, Double]()
+    trees.foreach { tree =>
+      // Aggregate feature importance vector for this tree
+      val importances = new OpenHashMap[Int, Double]()
+      computeFeatureImportance(tree.rootNode, importances)
+      // Normalize importance vector for this tree, and add it to total.
+      // TODO: In the future, also support normalizing by tree.rootNode.impurityStats.count?
+      val treeNorm = importances.map(_._2).sum
+      if (treeNorm != 0) {
+        importances.foreach { case (idx, impt) =>
+          val normImpt = impt / treeNorm
+          totalImportances.changeValue(idx, normImpt, _ + normImpt)
+        }
+      }
+    }
+    // Normalize importances
+    normalizeMapValues(totalImportances)
+    // Construct vector
+    val d = if (numFeatures != -1) {
+      numFeatures
+    } else {
+      // Find max feature index used in trees
+      val maxFeatureIndex = trees.map(_.maxSplitFeatureIndex()).max
+      maxFeatureIndex + 1
+    }
+    if (d == 0) {
+      assert(totalImportances.size == 0, s"Unknown error in computing feature" +
+        s" importance: No splits found, but some non-zero importances.")
+    }
+    val (indices, values) = totalImportances.iterator.toSeq.sortBy(_._1).unzip
+    Vectors.sparse(d, indices.toArray, values.toArray)
+  }
+
+  /**
+   * Given a Decision Tree model, compute the importance of each feature.
+   * This generalizes the idea of "Gini" importance to other losses,
+   * following the explanation of Gini importance from "Random Forests" documentation
+   * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+   *
+   * This feature importance is calculated as follows:
+   *  - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+   *    where gain is scaled by the number of instances passing through node
+   *  - Normalize importances for tree to sum to 1.
+   *
+   * @param tree  Decision tree to compute importances for.
+   * @param numFeatures  Number of features in model (even if not all are explicitly used by
+   *                     the model).
+   *                     If -1, then numFeatures is set based on the max feature index in all trees.
+   * @return  Feature importance values, of length numFeatures.
+   */
+  def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = {
+    featureImportances(Array(tree), numFeatures)
+  }
+
+  /**
+   * Recursive method for computing feature importances for one tree.
+   * This walks down the tree, adding to the importance of 1 feature at each node.
+   *
+   * @param node  Current node in recursion
+   * @param importances  Aggregate feature importances, modified by this method
+   */
+  def computeFeatureImportance(
+      node: Node,
+      importances: OpenHashMap[Int, Double]): Unit = {
+    node match {
+      case n: InternalNode =>
+        val feature = n.split.featureIndex
+        val scaledGain = n.gain * n.impurityStats.count
+        importances.changeValue(feature, scaledGain, _ + scaledGain)
+        computeFeatureImportance(n.leftChild, importances)
+        computeFeatureImportance(n.rightChild, importances)
+      case n: LeafNode =>
+      // do nothing
+    }
+  }
+
+  /**
+   * Normalize the values of this map to sum to 1, in place.
+   * If all values are 0, this method does nothing.
+   *
+   * @param map  Map with non-negative values.
+   */
+  def normalizeMapValues(map: OpenHashMap[Int, Double]): Unit = {
+    val total = map.map(_._2).sum
+    if (total != 0) {
+      val keys = map.iterator.map(_._1).toArray
+      keys.foreach { key => map.changeValue(key, 0.0, _ / total) }
+    }
+  }
+}
+
 /** Helper classes for tree model persistence */
 private[ml] object DecisionTreeModelReadWrite {
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index f3680ed044..bf7481e8a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -120,6 +120,31 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
   */
 
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature importance
+  /////////////////////////////////////////////////////////////////////////////
+  test("Feature importance with toy data") {
+    val numClasses = 2
+    val gbt = new GBTClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(3)
+      .setMaxIter(5)
+      .setSubsamplingRate(1.0)
+      .setStepSize(0.5)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+    assert(importances.toArray.sum === 1.0)
+    assert(importances.toArray.forall(_ >= 0.0))
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 84148a8a4a..dfb8418086 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -131,6 +131,29 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
   }
   */
 
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests of feature importance
+  /////////////////////////////////////////////////////////////////////////////
+  test("Feature importance with toy data") {
+    val gbt = new GBTRegressor()
+      .setMaxDepth(3)
+      .setMaxIter(5)
+      .setSubsamplingRate(1.0)
+      .setStepSize(0.5)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+    val importances = gbt.fit(df).featureImportances
+    val mostImportantFeature = importances.argmax
+    assert(mostImportantFeature === 1)
+    assert(importances.toArray.sum === 1.0)
+    assert(importances.toArray.forall(_ >= 0.0))
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 361366fde7..441338e74e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -471,7 +471,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     // Test feature importance computed at different subtrees.
     def testNode(node: Node, expected: Map[Int, Double]): Unit = {
       val map = new OpenHashMap[Int, Double]()
-      RandomForest.computeFeatureImportance(node, map)
+      TreeEnsembleModel.computeFeatureImportance(node, map)
       assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
     }
 
@@ -493,7 +493,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
         .asInstanceOf[DecisionTreeModel]
     }
-    val importances: Vector = RandomForest.featureImportances(trees, 2)
+    val importances: Vector = TreeEnsembleModel.featureImportances(trees, 2)
     val tree2norm = feature0importance + feature1importance
     val expected = Vectors.dense((1.0 + feature0importance / tree2norm) / 2.0,
       (feature1importance / tree2norm) / 2.0)
@@ -504,7 +504,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
     val map = new OpenHashMap[Int, Double]()
     map(0) = 1.0
     map(2) = 2.0
-    RandomForest.normalizeMapValues(map)
+    TreeEnsembleModel.normalizeMapValues(map)
     val expected = Map(0 -> 1.0 / 3.0, 2 -> 2.0 / 3.0)
     assert(mapToVec(map.toMap) ~== mapToVec(expected) relTol 0.01)
   }
-- 
cgit v1.2.3
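
The feature-importance computation moved into `TreeEnsembleModel` above (per-tree importances from scaled split gains, a per-tree normalization, averaging across trees, then a final normalization) can be illustrated with a small standalone sketch. The per-tree gain maps below are hypothetical toy numbers, and plain Scala collections stand in for the internal `OpenHashMap`:

```scala
// Hypothetical scaled gains per tree: feature index -> sum of (gain * node instance count).
val perTree = Seq(Map(0 -> 3.0, 1 -> 1.0), Map(1 -> 2.0))

// Normalize each tree's importances to sum to 1, then add them up across trees.
val summed = perTree
  .map { imps =>
    val norm = imps.values.sum
    imps.map { case (feature, gain) => feature -> gain / norm }
  }
  .foldLeft(Map.empty[Int, Double]) { (acc, imps) =>
    imps.foldLeft(acc) { case (a, (feature, v)) =>
      a + (feature -> (a.getOrElse(feature, 0.0) + v))
    }
  }

// Final normalization so the feature-importance vector sums to 1.
val total = summed.values.sum
val importances = summed.map { case (feature, v) => feature -> v / total }
// importances == Map(0 -> 0.375, 1 -> 0.625)
```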


From 63b200e8d4a05d5b744d437fd10781c6b5429da9 Mon Sep 17 00:00:00 2001
From: "wm624@hotmail.com" 
Date: Mon, 28 Mar 2016 22:33:25 -0700
Subject: [SPARK-14071][PYSPARK][ML] Change MLWritable.write to be a property

Add the `property` decorator to the MLWritable.write method, so we can use `.write` instead of `.write()`.

Add a new test to ml/tests.py to check whether `write` is a property.
./python/run-tests --python-executables=python2.7 --modules=pyspark-ml

Will test against the following Python executables: ['python2.7']
Will test the following Python modules: ['pyspark-ml']
Finished test(python2.7): pyspark.ml.evaluation (11s)
Finished test(python2.7): pyspark.ml.clustering (16s)
Finished test(python2.7): pyspark.ml.classification (24s)
Finished test(python2.7): pyspark.ml.recommendation (24s)
Finished test(python2.7): pyspark.ml.feature (39s)
Finished test(python2.7): pyspark.ml.regression (26s)
Finished test(python2.7): pyspark.ml.tuning (15s)
Finished test(python2.7): pyspark.ml.tests (30s)
Tests passed in 55 seconds

Author: wm624@hotmail.com 

Closes #11945 from wangmiao1981/fix_property.
---
 python/pyspark/ml/tests.py | 5 +++++
 python/pyspark/ml/util.py  | 4 +++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 224232ed7f..f6159b2c95 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -51,6 +51,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
 from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.util import keyword_only
+from pyspark.ml.util import MLWritable, MLWriter
 from pyspark.ml.wrapper import JavaWrapper
 from pyspark.mllib.linalg import DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
@@ -655,6 +656,10 @@ class PersistenceTest(PySparkTestCase):
             except OSError:
                 pass
 
+    def test_write_property(self):
+        lr = LinearRegression(maxIter=1)
+        self.assertTrue(isinstance(lr.write, MLWriter))
+
     def test_decisiontree_classifier(self):
         dt = DecisionTreeClassifier(maxDepth=1)
         path = tempfile.mkdtemp()
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6703851262..d4411fdfb9 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -134,13 +134,14 @@ class MLWritable(object):
     .. versionadded:: 2.0.0
     """
 
+    @property
     def write(self):
         """Returns an JavaMLWriter instance for this ML instance."""
         raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
 
     def save(self, path):
         """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
-        self.write().save(path)
+        self.write.save(path)
 
 
 @inherit_doc
@@ -149,6 +150,7 @@ class JavaMLWritable(MLWritable):
     (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
     """
 
+    @property
     def write(self):
         """Returns an JavaMLWriter instance for this ML instance."""
         return JavaMLWriter(self)
-- 
cgit v1.2.3


From 83775bc78e183791f75a99cdfbcd68a67ca0d472 Mon Sep 17 00:00:00 2001
From: Wenchen Fan 
Date: Tue, 29 Mar 2016 14:34:12 +0800
Subject: [SPARK-14158][SQL] implement buildReader for json data source

## What changes were proposed in this pull request?

This PR implements `buildReader` for the JSON data source and enables it in the new data source code path.
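
The change is internal to the data source; end users still go through the usual JSON entry points, which now exercise `buildReader` when the new file-scan code path is enabled. A minimal sketch (the path and column are hypothetical):

```scala
// Reading JSON through DataFrameReader drives DefaultSource.buildReader
// per partitioned file when the FileScanRDD path is on.
val people = sqlContext.read.json("/tmp/people.json")
people.filter("age > 21").show()
```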

## How was this patch tested?

existing tests

Author: Wenchen Fan 

Closes #11960 from cloud-fan/json.
---
 .../execution/datasources/FileSourceStrategy.scala |  4 +-
 .../datasources/HadoopFileLinesReader.scala        | 51 ++++++++++++++++++++++
 .../execution/datasources/json/JSONRelation.scala  | 37 +++++++++++++++-
 .../execution/datasources/json/JacksonParser.scala |  2 +-
 4 files changed, 90 insertions(+), 4 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 4b04fec57d..76a724e51e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -58,7 +58,8 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
     case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _))
       if (files.fileFormat.toString == "TestFileFormat" ||
          files.fileFormat.isInstanceOf[parquet.DefaultSource] ||
-         files.fileFormat.toString == "ORC") &&
+         files.fileFormat.toString == "ORC" ||
+         files.fileFormat.isInstanceOf[json.DefaultSource]) &&
          files.sqlContext.conf.parquetFileScan =>
       // Filters on this relation fall into four categories based on where we can use them to avoid
       // reading unneeded data:
@@ -138,7 +139,6 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
 
           val splitFiles = selectedPartitions.flatMap { partition =>
             partition.files.flatMap { file =>
-              assert(file.getLen != 0, file.toString)
               (0L to file.getLen by maxSplitBytes).map { offset =>
                 val remaining = file.getLen - offset
                 val size = if (remaining > maxSplitBytes) maxSplitBytes else remaining
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
new file mode 100644
index 0000000000..18f9b55895
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/HadoopFileLinesReader.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.{FileSplit, LineRecordReader}
+import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+
+/**
+ * An adaptor from a [[PartitionedFile]] to an [[Iterator]] of [[Text]], which are all of the lines
+ * in that file.
+ */
+class HadoopFileLinesReader(file: PartitionedFile, conf: Configuration) extends Iterator[Text] {
+  private val iterator = {
+    val fileSplit = new FileSplit(
+      new Path(new URI(file.filePath)),
+      file.start,
+      file.length,
+      // TODO: Implement Locality
+      Array.empty)
+    val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
+    val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
+    val reader = new LineRecordReader()
+    reader.initialize(fileSplit, hadoopAttemptContext)
+    new RecordReaderIterator(reader)
+  }
+
+  override def hasNext: Boolean = iterator.hasNext
+
+  override def next(): Text = iterator.next()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
index 3bf0af0efa..21fc1224ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json
 import java.io.CharArrayWriter
 
 import com.fasterxml.jackson.core.JsonFactory
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.io.{LongWritable, NullWritable, Text}
 import org.apache.hadoop.mapred.{JobConf, TextInputFormat}
@@ -32,7 +33,8 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, Row, SQLContext}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -120,6 +122,39 @@ class DefaultSource extends FileFormat with DataSourceRegister {
     }
   }
 
+  override def buildReader(
+      sqlContext: SQLContext,
+      partitionSchema: StructType,
+      dataSchema: StructType,
+      filters: Seq[Filter],
+      options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = {
+    val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration)
+    val broadcastedConf =
+      sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf))
+
+    val parsedOptions: JSONOptions = new JSONOptions(options)
+    val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord
+      .getOrElse(sqlContext.conf.columnNameOfCorruptRecord)
+
+    val fullSchema = dataSchema.toAttributes ++ partitionSchema.toAttributes
+    val joinedRow = new JoinedRow()
+
+    file => {
+      val lines = new HadoopFileLinesReader(file, broadcastedConf.value.value).map(_.toString)
+
+      val rows = JacksonParser.parseJson(
+        lines,
+        dataSchema,
+        columnNameOfCorruptRecord,
+        parsedOptions)
+
+      val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema)
+      rows.map { row =>
+        appendPartitionColumns(joinedRow(row, file.partitionValues))
+      }
+    }
+  }
+
   private def createBaseRdd(sqlContext: SQLContext, inputPaths: Seq[FileStatus]): RDD[String] = {
     val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration)
     val conf = job.getConfiguration
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
index 00c14adf07..8bc53bae6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala
@@ -250,7 +250,7 @@ object JacksonParser extends Logging {
     new GenericArrayData(values.toArray)
   }
 
-  private def parseJson(
+  def parseJson(
       input: Iterator[String],
       schema: StructType,
       columnNameOfCorruptRecords: String,
-- 
cgit v1.2.3


From 425bcf6d6844732fe402af05472ad87b4e032cb6 Mon Sep 17 00:00:00 2001
From: Bryan Cutler 
Date: Tue, 29 Mar 2016 12:30:30 +0200
Subject: [SPARK-13963][ML] Adding binary toggle param to HashingTF

## What changes were proposed in this pull request?
Adding a binary toggle parameter to ml.feature.HashingTF, as well as to mllib.feature.HashingTF, since the former wraps the latter's functionality. If true, this parameter sets all non-zero term counts to 1, transforming term-count features into binary values that are well suited for discrete probability models.
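
A minimal usage sketch of the new toggle on the ML side, assuming a DataFrame `df` with a `words` column holding token sequences:

```scala
import org.apache.spark.ml.feature.HashingTF

val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(1 << 10)
  .setBinary(true)  // all non-zero term counts are hashed to 1.0

val featurized = hashingTF.transform(df)
```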

## How was this patch tested?
Added unit tests for ML and MLlib

Author: Bryan Cutler 

Closes #11832 from BryanCutler/binary-param-HashingTF-SPARK-13963.
---
 .../org/apache/spark/ml/feature/HashingTF.scala    | 23 ++++++++++++++++++---
 .../org/apache/spark/mllib/feature/HashingTF.scala | 15 +++++++++++++-
 .../apache/spark/ml/feature/HashingTFSuite.scala   | 24 +++++++++++++++++++++-
 .../spark/mllib/feature/HashingTFSuite.scala       | 12 +++++++++++
 4 files changed, 69 insertions(+), 5 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 61a78d73c4..0f7ae5a100 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
@@ -52,7 +52,18 @@ class HashingTF(override val uid: String)
   val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
     ParamValidators.gt(0))
 
-  setDefault(numFeatures -> (1 << 18))
+  /**
+   * Binary toggle to control term frequency counts.
+   * If true, all non-zero counts are set to 1.  This is useful for discrete probabilistic
+   * models that model binary events rather than integer counts.
+   * (default = false)
+   * @group param
+   */
+  val binary = new BooleanParam(this, "binary", "If true, all non zero counts are set to 1. " +
+    "This is useful for discrete probabilistic models that model binary events rather " +
+    "than integer counts")
+
+  setDefault(numFeatures -> (1 << 18), binary -> false)
 
   /** @group getParam */
   def getNumFeatures: Int = $(numFeatures)
@@ -60,9 +71,15 @@ class HashingTF(override val uid: String)
   /** @group setParam */
   def setNumFeatures(value: Int): this.type = set(numFeatures, value)
 
+  /** @group getParam */
+  def getBinary: Boolean = $(binary)
+
+  /** @group setParam */
+  def setBinary(value: Boolean): this.type = set(binary, value)
+
   override def transform(dataset: DataFrame): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
-    val hashingTF = new feature.HashingTF($(numFeatures))
+    val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
     val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
     val metadata = outputSchema($(outputCol)).metadata
     dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index c93ed64183..47c9e850a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -36,11 +36,23 @@ import org.apache.spark.util.Utils
 @Since("1.1.0")
 class HashingTF(val numFeatures: Int) extends Serializable {
 
+  private var binary = false
+
   /**
    */
   @Since("1.1.0")
   def this() = this(1 << 20)
 
+  /**
+   * If true, term frequency vector will be binary such that non-zero term counts will be set to 1
+   * (default: false)
+   */
+  @Since("2.0.0")
+  def setBinary(value: Boolean): this.type = {
+    binary = value
+    this
+  }
+
   /**
    * Returns the index of the input term.
    */
@@ -53,9 +65,10 @@ class HashingTF(val numFeatures: Int) extends Serializable {
   @Since("1.1.0")
   def transform(document: Iterable[_]): Vector = {
     val termFrequencies = mutable.HashMap.empty[Int, Double]
+    val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
     document.foreach { term =>
       val i = indexOf(term)
-      termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)
+      termFrequencies.put(i, setTF(i))
     }
     Vectors.sparse(numFeatures, termFrequencies.toSeq)
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 0dcd0f4946..addd733c20 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -46,12 +46,30 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     require(attrGroup.numAttributes === Some(n))
     val features = output.select("features").first().getAs[Vector](0)
     // Assume perfect hash on "a", "b", "c", and "d".
-    def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
+    def idx: Any => Int = featureIdx(n)
     val expected = Vectors.sparse(n,
       Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
     assert(features ~== expected absTol 1e-14)
   }
 
+  test("applying binary term freqs") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, "a a b c c c".split(" ").toSeq)
+    )).toDF("id", "words")
+    val n = 100
+    val hashingTF = new HashingTF()
+        .setInputCol("words")
+        .setOutputCol("features")
+        .setNumFeatures(n)
+        .setBinary(true)
+    val output = hashingTF.transform(df)
+    val features = output.select("features").first().getAs[Vector](0)
+    def idx: Any => Int = featureIdx(n)  // Assume perfect hash on input features
+    val expected = Vectors.sparse(n,
+      Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
+    assert(features ~== expected absTol 1e-14)
+  }
+
   test("read/write") {
     val t = new HashingTF()
       .setInputCol("myInputCol")
@@ -59,4 +77,8 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
       .setNumFeatures(10)
     testDefaultReadWrite(t)
   }
+
+  private def featureIdx(numFeatures: Int)(term: Any): Int = {
+    Utils.nonNegativeMod(term.##, numFeatures)
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index cf279c0233..6c07e3a5ce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
 
 class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -48,4 +49,15 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
     val docs = sc.parallelize(localDocs, 2)
     assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet)
   }
+
+  test("applying binary term freqs") {
+    val hashingTF = new HashingTF(100).setBinary(true)
+    val doc = "a a b c c c".split(" ")
+    val n = hashingTF.numFeatures
+    val expected = Vectors.sparse(n, Seq(
+      (hashingTF.indexOf("a"), 1.0),
+      (hashingTF.indexOf("b"), 1.0),
+      (hashingTF.indexOf("c"), 1.0)))
+    assert(hashingTF.transform(doc) ~== expected absTol 1e-14)
+  }
 }
-- 
cgit v1.2.3


From a632bb56f8867df39a78d7f01fb870f548b09815 Mon Sep 17 00:00:00 2001
From: Cheng Lian 
Date: Tue, 29 Mar 2016 20:56:01 +0800
Subject: [SPARK-14208][SQL] Renames spark.sql.parquet.fileScan

## What changes were proposed in this pull request?

Renames the SQL option `spark.sql.parquet.fileScan` to `spark.sql.sources.fileScan`, since all `HadoopFsRelation` based data sources are now being migrated to the `FileScanRDD` code path.
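
The option stays internal, but it can still be toggled explicitly, e.g. to fall back to the old scan path while the migration is in progress (a sketch):

```scala
// Disable the FileScanRDD-based scan for HadoopFsRelation sources ...
sqlContext.setConf("spark.sql.sources.fileScan", "false")
// ... or re-enable it (the default).
sqlContext.setConf("spark.sql.sources.fileScan", "true")
```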

## How was this patch tested?

None.

Author: Cheng Lian 

Closes #12003 from liancheng/spark-14208-option-renaming.
---
 .../spark/sql/execution/datasources/FileSourceStrategy.scala      | 2 +-
 .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala    | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 76a724e51e..20fda95154 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -60,7 +60,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging {
          files.fileFormat.isInstanceOf[parquet.DefaultSource] ||
          files.fileFormat.toString == "ORC" ||
          files.fileFormat.isInstanceOf[json.DefaultSource]) &&
-         files.sqlContext.conf.parquetFileScan =>
+         files.sqlContext.conf.useFileScan =>
       // Filters on this relation fall into four categories based on where we can use them to avoid
       // reading unneeded data:
       //  - partition keys only - used to prune directories to read
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 77af0e000b..ca6ba4c643 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -288,9 +288,9 @@ object SQLConf {
     defaultValue = Some(true),
     doc = "Whether the query analyzer should be case sensitive or not.")
 
-  val PARQUET_FILE_SCAN = booleanConf("spark.sql.parquet.fileScan",
+  val USE_FILE_SCAN = booleanConf("spark.sql.sources.fileScan",
     defaultValue = Some(true),
-    doc = "Use the new FileScanRDD path for reading parquet data.",
+    doc = "Use the new FileScanRDD path for reading HDFS based data sources.",
     isPublic = false)
 
   val PARQUET_SCHEMA_MERGING_ENABLED = booleanConf("spark.sql.parquet.mergeSchema",
@@ -583,9 +583,9 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
 
   def useCompression: Boolean = getConf(COMPRESS_CACHED)
 
-  def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
+  def useFileScan: Boolean = getConf(USE_FILE_SCAN)
 
-  def parquetFileScan: Boolean = getConf(PARQUET_FILE_SCAN)
+  def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
 
   def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA)
 
-- 
cgit v1.2.3


From d2a819a6363190b946986ebf6f8001d520098c3b Mon Sep 17 00:00:00 2001
From: Yuhao Yang 
Date: Tue, 29 Mar 2016 09:16:50 -0700
Subject: [SPARK-14154][MLLIB] Simplify the implementation for
 Kolmogorov–Smirnov test
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-14154

I just read the code for KolmogorovSmirnovTest and found that it could be much simplified by following the original definition.

Sending a PR for discussion.
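
The simplification follows directly from the definition of the one-sample statistic, D_n = sup_x |F_n(x) - F(x)|: for a sorted sample the supremum is attained at a sample point, just before or just after the empirical CDF jumps. A small local sketch of the same computation, using plain collections and a uniform CDF chosen purely for illustration:

```scala
// Sorted sample and a reference CDF (uniform on [0, 1] here).
val sample = Array(0.1, 0.4, 0.8)
val n = sample.length.toDouble
val cdf: Double => Double = x => x

// At the i-th sorted point the empirical CDF jumps from i/n to (i+1)/n,
// so the largest deviation at that point is one of these two differences.
val ksStat = sample.zipWithIndex.map { case (v, i) =>
  val f = cdf(v)
  math.max(f - i / n, (i + 1) / n - f)
}.max
// ksStat == 2.0 / 3.0 - 0.4, roughly 0.267, for this toy sample
```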

## How was this patch tested?
unit test

Author: Yuhao Yang 

Closes #11954 from hhbyyh/ksoptimize.
---
 .../mllib/stat/test/KolmogorovSmirnovTest.scala    | 77 ++--------------------
 1 file changed, 4 insertions(+), 73 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
index baf9e5e7d1..0ec8975fed 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala
@@ -64,11 +64,10 @@ private[stat] object KolmogorovSmirnovTest extends Logging {
    */
   def testOneSample(data: RDD[Double], cdf: Double => Double): KolmogorovSmirnovTestResult = {
     val n = data.count().toDouble
-    val localData = data.sortBy(x => x).mapPartitions { part =>
-      val partDiffs = oneSampleDifferences(part, n, cdf) // local distances
-      searchOneSampleCandidates(partDiffs) // candidates: local extrema
-    }.collect()
-    val ksStat = searchOneSampleStatistic(localData, n) // result: global extreme
+    val ksStat = data.sortBy(x => x).zipWithIndex().map { case (v, i) =>
+      val f = cdf(v)
+      math.max(f - i / n, (i + 1) / n - f)
+    }.max()
     evalOneSampleP(ksStat, n.toLong)
   }
 
@@ -84,74 +83,6 @@ private[stat] object KolmogorovSmirnovTest extends Logging {
     testOneSample(data, cdf)
   }
 
-  /**
-   * Calculate unadjusted distances between the empirical CDF and the theoretical CDF in a
-   * partition
-   * @param partData `Iterator[Double]` 1 partition of a sorted RDD
-   * @param n `Double` the total size of the RDD
-   * @param cdf `Double => Double` a function the calculates the theoretical CDF of a value
-   * @return `Iterator[(Double, Double)] `Unadjusted (ie. off by a constant) potential extrema
-   *        in a partition. The first element corresponds to the (empirical CDF - 1/N) - CDF,
-   *        the second element corresponds to empirical CDF - CDF.  We can then search the resulting
-   *        iterator for the minimum of the first and the maximum of the second element, and provide
-   *        this as a partition's candidate extrema
-   */
-  private def oneSampleDifferences(partData: Iterator[Double], n: Double, cdf: Double => Double)
-    : Iterator[(Double, Double)] = {
-    // zip data with index (within that partition)
-    // calculate local (unadjusted) empirical CDF and subtract CDF
-    partData.zipWithIndex.map { case (v, ix) =>
-      // dp and dl are later adjusted by constant, when global info is available
-      val dp = (ix + 1) / n
-      val dl = ix / n
-      val cdfVal = cdf(v)
-      (dl - cdfVal, dp - cdfVal)
-    }
-  }
-
-  /**
-   * Search the unadjusted differences in a partition and return the
-   * two extrema (furthest below and furthest above CDF), along with a count of elements in that
-   * partition
-   * @param partDiffs `Iterator[(Double, Double)]` the unadjusted differences between empirical CDF
-   *                 and CDFin a partition, which come as a tuple of
-   *                 (empirical CDF - 1/N - CDF, empirical CDF - CDF)
-   * @return `Iterator[(Double, Double, Double)]` the local extrema and a count of elements
-   */
-  private def searchOneSampleCandidates(partDiffs: Iterator[(Double, Double)])
-    : Iterator[(Double, Double, Double)] = {
-    val initAcc = (Double.MaxValue, Double.MinValue, 0.0)
-    val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) =>
-      (math.min(pMin, dl), math.max(pMax, dp), pCt + 1)
-    }
-    val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults)
-    results.iterator
-  }
-
-  /**
-   * Find the global maximum distance between empirical CDF and CDF (i.e. the KS statistic) after
-   * adjusting local extrema estimates from individual partitions with the amount of elements in
-   * preceding partitions
-   * @param localData `Array[(Double, Double, Double)]` A local array containing the collected
-   *                 results of `searchOneSampleCandidates` across all partitions
-   * @param n `Double`The size of the RDD
-   * @return The one-sample Kolmogorov Smirnov Statistic
-   */
-  private def searchOneSampleStatistic(localData: Array[(Double, Double, Double)], n: Double)
-    : Double = {
-    val initAcc = (Double.MinValue, 0.0)
-    // adjust differences based on the number of elements preceding it, which should provide
-    // the correct distance between empirical CDF and CDF
-    val results = localData.foldLeft(initAcc) { case ((prevMax, prevCt), (minCand, maxCand, ct)) =>
-      val adjConst = prevCt / n
-      val dist1 = math.abs(minCand + adjConst)
-      val dist2 = math.abs(maxCand + adjConst)
-      val maxVal = Array(prevMax, dist1, dist2).max
-      (maxVal, prevCt + ct)
-    }
-    results._1
-  }
-
   /**
    * A convenience function that allows running the KS test for 1 set of sample data against
    * a named distribution
-- 
cgit v1.2.3


From 15c0b0006b3d04434b505210df541aeb28a51de8 Mon Sep 17 00:00:00 2001
From: Carson Wang 
Date: Tue, 29 Mar 2016 11:07:58 -0700
Subject: [SPARK-14232][WEBUI] Fix event timeline display issue when an
 executor is removed with a multiple line reason.

## What changes were proposed in this pull request?
The event timeline doesn't show on the job page if an executor is removed with a multi-line reason. This PR replaces all newline characters in the reason string with spaces.
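
The reason string is interpolated into a single-quoted JavaScript string in the generated timeline markup, so an embedded newline breaks the script. The fix collapses the reason to one line before embedding it; a tiny sketch with a made-up reason:

```scala
val reason = "Container killed by YARN for exceeding memory limits.\nExit status: 143."
val singleLine = reason.replace("\n", " ")
// singleLine: "Container killed by YARN for exceeding memory limits. Exit status: 143."
```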

![timelineerror](https://cloud.githubusercontent.com/assets/9278199/14100211/5fd4cd30-f5be-11e5-9cea-f32651a4cd62.jpg)

## How was this patch tested?
Verified on the Web UI.

Author: Carson Wang 

Closes #12029 from carsonwang/eventTimeline.
---
 core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 2 +-
 core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index d1c8b3089a..d5f15f160b 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -148,7 +148,7 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
                |    'Removed at ${UIUtils.formatDate(new Date(event.finishTime.get))}' +
                |    '${
                         if (event.finishReason.isDefined) {
-                          s"""<br>Reason: ${event.finishReason.get}"""
+                          s"""<br>Reason: ${event.finishReason.get.replace("\n", " ")}"""
                         } else {
                           ""
                         }
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
index 654d988807..645e2d2e36 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -122,7 +122,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
                |    'Removed at ${UIUtils.formatDate(new Date(event.finishTime.get))}' +
                |    '${
                         if (event.finishReason.isDefined) {
-                          s"""<br>Reason: ${event.finishReason.get}"""
+                          s"""<br>Reason: ${event.finishReason.get.replace("\n", " ")}"""
                         } else {
                           ""
                         }
-- 
cgit v1.2.3


From d26c42982c18da8fb1b21c9eb75aef9364d1b992 Mon Sep 17 00:00:00 2001
From: Jakob Odersky 
Date: Tue, 29 Mar 2016 11:10:15 -0700
Subject: [SPARK-10570][CORE] Add version info to json api

Add a new api endpoint `/api/v1/version` to retrieve various version info. This PR only adds support
for finding the current spark version, however other version info such as jvm or scala versions can
easily be added.

Author: Jakob Odersky 

Closes #10760 from jodersky/SPARK-10570.
---
 .../apache/spark/status/api/v1/ApiRootResource.scala |  6 ++++++
 .../apache/spark/status/api/v1/VersionResource.scala | 30 ++++++++++++++++++++
 .../scala/org/apache/spark/status/api/v1/api.scala   |  3 +++
 3 files changed, 39 insertions(+)
 create mode 100644 core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala

diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
index 50b6ba67e9..ba9cd711f1 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
@@ -177,6 +177,12 @@ private[v1] class ApiRootResource extends UIRootFromServletContext {
       @PathParam("attemptId") attemptId: String): EventLogDownloadResource = {
     new EventLogDownloadResource(uiRoot, appId, Some(attemptId))
   }
+
+  @Path("version")
+  def getVersion(): VersionResource = {
+    new VersionResource(uiRoot)
+  }
+
 }
 
 private[spark] object ApiRootResource {
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala
new file mode 100644
index 0000000000..673da1ce36
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/VersionResource.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.status.api.v1
+
+import javax.ws.rs._
+import javax.ws.rs.core.MediaType
+
+@Produces(Array(MediaType.APPLICATION_JSON))
+private[v1] class VersionResource(ui: UIRoot) {
+
+  @GET
+  def getVersionInfo(): VersionInfo = new VersionInfo(
+    org.apache.spark.SPARK_VERSION
+  )
+
+}
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index 909dd0c07e..d43868bbcb 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -237,3 +237,6 @@ class AccumulableInfo private[spark](
     val name: String,
     val update: Option[String],
     val value: String)
+
+class VersionInfo private[spark](
+  val spark: String)
-- 
cgit v1.2.3


From d612228eff9ed6589b2a94658986ec06ed833bf5 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun 
Date: Tue, 29 Mar 2016 12:45:43 -0700
Subject: [MINOR][SQL] Fix typos by replacing 'much' with 'match'.

## What changes were proposed in this pull request?

This PR fixes two trivial typos: 'does not **much**' --> 'does not **match**'.

## How was this patch tested?

Manual.

Author: Dongjoon Hyun 

Closes #12042 from dongjoon-hyun/fix_typo_by_replacing_much_with_match.
---
 .../main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala   | 2 +-
 .../src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 8bb8e09a28..9f270236ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -291,5 +291,5 @@ case class CatalogRelation(
   override def output: Seq[Attribute] = Seq.empty
 
   require(metadata.identifier.database == Some(db),
-    "provided database does not much the one specified in the table definition")
+    "provided database does not match the one specified in the table definition")
 }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index f75509fe80..11205ae67c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -76,7 +76,7 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat
   private def requireDbMatches(db: String, table: CatalogTable): Unit = {
     if (table.identifier.database != Some(db)) {
       throw new AnalysisException(
-        s"Provided database $db does not much the one specified in the " +
+        s"Provided database $db does not match the one specified in the " +
         s"table definition (${table.identifier.database.getOrElse("n/a")})")
     }
   }
-- 
cgit v1.2.3


From 838cb4583dd68a9d7ef3bd74b9b4d33f32f177fc Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun 
Date: Tue, 29 Mar 2016 12:47:30 -0700
Subject: [MINOR][SQL] Fix exception message to print string-array correctly.

## What changes were proposed in this pull request?

This PR is a simple fix for an exception message to print `string[]` content correctly.

```java
String[] colPath = requestedSchema.getPaths().get(i);
...
- throw new IOException("Required column is missing in data file. Col: " + colPath);
+ throw new IOException("Required column is missing in data file. Col: " + Arrays.toString(colPath));
```

## How was this patch tested?

Manual.

Author: Dongjoon Hyun 

Closes #12041 from dongjoon-hyun/fix_exception_message_with_string_array.
---
 .../execution/datasources/parquet/VectorizedParquetRecordReader.java | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 5bfde55c3b..0bdf4aab29 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.parquet;
 
 import java.io.IOException;
+import java.util.Arrays;
 import java.util.List;
 
 import org.apache.hadoop.mapreduce.InputSplit;
@@ -269,7 +270,8 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
       } else {
         if (requestedSchema.getColumns().get(i).getMaxDefinitionLevel() == 0) {
           // Column is missing in data but the required data is non-nullable. This file is invalid.
-          throw new IOException("Required column is missing in data file. Col: " + colPath);
+          throw new IOException("Required column is missing in data file. Col: " +
+            Arrays.toString(colPath));
         }
         missingColumns[i] = true;
       }
-- 
cgit v1.2.3


From e58c4cb3c5a95f44e357b99a2f0d0e1201d91e7a Mon Sep 17 00:00:00 2001
From: Eric Liang 
Date: Tue, 29 Mar 2016 13:31:51 -0700
Subject: [SPARK-14227][SQL] Add method for printing out generated code for debugging

## What changes were proposed in this pull request?

This adds `debugCodegen` to the debug package for query execution.

## How was this patch tested?

Unit and manual testing.

Output example:

```
scala> import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.debug._

scala> sqlContext.range(100).groupBy("id").count().orderBy("id").debugCodegen()
Found 3 WholeStageCodegen subtrees.
== Subtree 1 / 3 ==
WholeStageCodegen
:  +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L])
:     +- Range 0, 1, 1, 100, [id#0L]

Generated code:
/* 001 */ public Object generate(Object[] references) {
/* 002 */ return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ /** Codegened pipeline for:
/* 006 */ * TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L])
/* 007 */ +- Range 0, 1, 1, 100, [id#0L]
/* 008 */ */
/* 009 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 010 */ private Object[] references;
/* 011 */ private boolean agg_initAgg;
/* 012 */ private org.apache.spark.sql.execution.aggregate.TungstenAggregate agg_plan;
/* 013 */ private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap;
/* 014 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter;
/* 015 */ private org.apache.spark.unsafe.KVIterator agg_mapIter;
/* 016 */ private org.apache.spark.sql.execution.metric.LongSQLMetric range_numOutputRows;
/* 017 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue range_metricValue;
/* 018 */ private boolean range_initRange;
/* 019 */ private long range_partitionEnd;
/* 020 */ private long range_number;
/* 021 */ private boolean range_overflow;
/* 022 */ private scala.collection.Iterator range_input;
/* 023 */ private UnsafeRow range_result;
/* 024 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder range_holder;
/* 025 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter range_rowWriter;
/* 026 */ private UnsafeRow agg_result;
/* 027 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder;
/* 028 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter;
/* 029 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowJoiner agg_unsafeRowJoiner;
/* 030 */ private org.apache.spark.sql.execution.metric.LongSQLMetric wholestagecodegen_numOutputRows;
/* 031 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue wholestagecodegen_metricValue;
/* 032 */
/* 033 */ public GeneratedIterator(Object[] references) {
/* 034 */ this.references = references;
/* 035 */ }
/* 036 */
/* 037 */ public void init(scala.collection.Iterator inputs[]) {
/* 038 */ agg_initAgg = false;
/* 039 */ this.agg_plan = (org.apache.spark.sql.execution.aggregate.TungstenAggregate) references[0];
/* 040 */ agg_hashMap = agg_plan.createHashMap();
/* 041 */
/* 042 */ this.range_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1];
/* 043 */ range_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) range_numOutputRows.localValue();
/* 044 */ range_initRange = false;
/* 045 */ range_partitionEnd = 0L;
/* 046 */ range_number = 0L;
/* 047 */ range_overflow = false;
/* 048 */ range_input = inputs[0];
/* 049 */ range_result = new UnsafeRow(1);
/* 050 */ this.range_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(range_result, 0);
/* 051 */ this.range_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(range_holder, 1);
/* 052 */ agg_result = new UnsafeRow(1);
/* 053 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0);
/* 054 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1);
/* 055 */ agg_unsafeRowJoiner = agg_plan.createUnsafeJoiner();
/* 056 */ this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[2];
/* 057 */ wholestagecodegen_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) wholestagecodegen_numOutputRows.localValue();
/* 058 */ }
/* 059 */
/* 060 */ private void agg_doAggregateWithKeys() throws java.io.IOException {
/* 061 */ /*** PRODUCE: Range 0, 1, 1, 100, [id#0L] */
/* 062 */
/* 063 */ // initialize Range
/* 064 */ if (!range_initRange) {
/* 065 */ range_initRange = true;
/* 066 */ if (range_input.hasNext()) {
/* 067 */ initRange(((InternalRow) range_input.next()).getInt(0));
/* 068 */ } else {
/* 069 */ return;
/* 070 */ }
/* 071 */ }
/* 072 */
/* 073 */ while (!range_overflow && range_number < range_partitionEnd) {
/* 074 */ long range_value = range_number;
/* 075 */ range_number += 1L;
/* 076 */ if (range_number < range_value ^ 1L < 0) {
/* 077 */ range_overflow = true;
/* 078 */ }
/* 079 */
/* 080 */ /*** CONSUME: TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) */
/* 081 */
/* 082 */ // generate grouping key
/* 083 */ agg_rowWriter.write(0, range_value);
/* 084 */ /* hash(input[0, bigint], 42) */
/* 085 */ int agg_value1 = 42;
/* 086 */
/* 087 */ agg_value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(range_value, agg_value1);
/* 088 */ UnsafeRow agg_aggBuffer = null;
/* 089 */ if (true) {
/* 090 */ // try to get the buffer from hash map
/* 091 */ agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1);
/* 092 */ }
/* 093 */ if (agg_aggBuffer == null) {
/* 094 */ if (agg_sorter == null) {
/* 095 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter();
/* 096 */ } else {
/* 097 */ agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter());
/* 098 */ }
/* 099 */
/* 100 */ // the hash map had be spilled, it should have enough memory now,
/* 101 */ // try to allocate buffer again.
/* 102 */ agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1);
/* 103 */ if (agg_aggBuffer == null) {
/* 104 */ // failed to allocate the first page
/* 105 */ throw new OutOfMemoryError("No enough memory for aggregation");
/* 106 */ }
/* 107 */ }
/* 108 */
/* 109 */ // evaluate aggregate function
/* 110 */ /* (input[0, bigint] + 1) */
/* 111 */ /* input[0, bigint] */
/* 112 */ long agg_value4 = agg_aggBuffer.getLong(0);
/* 113 */
/* 114 */ long agg_value3 = -1L;
/* 115 */ agg_value3 = agg_value4 + 1L;
/* 116 */ // update aggregate buffer
/* 117 */ agg_aggBuffer.setLong(0, agg_value3);
/* 118 */
/* 119 */ if (shouldStop()) return;
/* 120 */ }
/* 121 */
/* 122 */ agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter);
/* 123 */ }
/* 124 */
/* 125 */ private void initRange(int idx) {
/* 126 */ java.math.BigInteger index = java.math.BigInteger.valueOf(idx);
/* 127 */ java.math.BigInteger numSlice = java.math.BigInteger.valueOf(1L);
/* 128 */ java.math.BigInteger numElement = java.math.BigInteger.valueOf(100L);
/* 129 */ java.math.BigInteger step = java.math.BigInteger.valueOf(1L);
/* 130 */ java.math.BigInteger start = java.math.BigInteger.valueOf(0L);
/* 131 */
/* 132 */ java.math.BigInteger st = index.multiply(numElement).divide(numSlice).multiply(step).add(start);
/* 133 */ if (st.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 134 */ range_number = Long.MAX_VALUE;
/* 135 */ } else if (st.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 136 */ range_number = Long.MIN_VALUE;
/* 137 */ } else {
/* 138 */ range_number = st.longValue();
/* 139 */ }
/* 140 */
/* 141 */ java.math.BigInteger end = index.add(java.math.BigInteger.ONE).multiply(numElement).divide(numSlice)
/* 142 */ .multiply(step).add(start);
/* 143 */ if (end.compareTo(java.math.BigInteger.valueOf(Long.MAX_VALUE)) > 0) {
/* 144 */ range_partitionEnd = Long.MAX_VALUE;
/* 145 */ } else if (end.compareTo(java.math.BigInteger.valueOf(Long.MIN_VALUE)) < 0) {
/* 146 */ range_partitionEnd = Long.MIN_VALUE;
/* 147 */ } else {
/* 148 */ range_partitionEnd = end.longValue();
/* 149 */ }
/* 150 */
/* 151 */ range_metricValue.add((range_partitionEnd - range_number) / 1L);
/* 152 */ }
/* 153 */
/* 154 */ protected void processNext() throws java.io.IOException {
/* 155 */ /*** PRODUCE: TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) */
/* 156 */
/* 157 */ if (!agg_initAgg) {
/* 158 */ agg_initAgg = true;
/* 159 */ agg_doAggregateWithKeys();
/* 160 */ }
/* 161 */
/* 162 */ // output the result
/* 163 */ while (agg_mapIter.next()) {
/* 164 */ wholestagecodegen_metricValue.add(1);
/* 165 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey();
/* 166 */ UnsafeRow agg_aggBuffer1 = (UnsafeRow) agg_mapIter.getValue();
/* 167 */
/* 168 */ UnsafeRow agg_resultRow = agg_unsafeRowJoiner.join(agg_aggKey, agg_aggBuffer1);
/* 169 */
/* 170 */ /*** CONSUME: WholeStageCodegen */
/* 171 */
/* 172 */ append(agg_resultRow);
/* 173 */
/* 174 */ if (shouldStop()) return;
/* 175 */ }
/* 176 */
/* 177 */ agg_mapIter.close();
/* 178 */ if (agg_sorter == null) {
/* 179 */ agg_hashMap.free();
/* 180 */ }
/* 181 */ }
/* 182 */ }

== Subtree 2 / 3 ==
WholeStageCodegen
:  +- Sort [id#0L ASC], true, 0
:     +- INPUT
+- Exchange rangepartitioning(id#0L ASC, 200), None
   +- WholeStageCodegen
      :  +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L])
      :     +- INPUT
      +- Exchange
hashpartitioning(id#0L, 200), None +- WholeStageCodegen : +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) : +- Range 0, 1, 1, 100, [id#0L] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ /** Codegened pipeline for: /* 006 */ * Sort [id#0L ASC], true, 0 /* 007 */ +- INPUT /* 008 */ */ /* 009 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 010 */ private Object[] references; /* 011 */ private boolean sort_needToSort; /* 012 */ private org.apache.spark.sql.execution.Sort sort_plan; /* 013 */ private org.apache.spark.sql.execution.UnsafeExternalRowSorter sort_sorter; /* 014 */ private org.apache.spark.executor.TaskMetrics sort_metrics; /* 015 */ private scala.collection.Iterator sort_sortedIter; /* 016 */ private scala.collection.Iterator inputadapter_input; /* 017 */ private org.apache.spark.sql.execution.metric.LongSQLMetric sort_dataSize; /* 018 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue sort_metricValue; /* 019 */ private org.apache.spark.sql.execution.metric.LongSQLMetric sort_spillSize; /* 020 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue sort_metricValue1; /* 021 */ /* 022 */ public GeneratedIterator(Object[] references) { /* 023 */ this.references = references; /* 024 */ } /* 025 */ /* 026 */ public void init(scala.collection.Iterator inputs[]) { /* 027 */ sort_needToSort = true; /* 028 */ this.sort_plan = (org.apache.spark.sql.execution.Sort) references[0]; /* 029 */ sort_sorter = sort_plan.createSorter(); /* 030 */ sort_metrics = org.apache.spark.TaskContext.get().taskMetrics(); /* 031 */ /* 032 */ inputadapter_input = inputs[0]; /* 033 */ this.sort_dataSize = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1]; /* 034 */ sort_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) sort_dataSize.localValue(); /* 035 */ this.sort_spillSize = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[2]; /* 036 */ sort_metricValue1 = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) sort_spillSize.localValue(); /* 037 */ } /* 038 */ /* 039 */ private void sort_addToSorter() throws java.io.IOException { /* 040 */ /*** PRODUCE: INPUT */ /* 041 */ /* 042 */ while (inputadapter_input.hasNext()) { /* 043 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 044 */ /*** CONSUME: Sort [id#0L ASC], true, 0 */ /* 045 */ /* 046 */ sort_sorter.insertRow((UnsafeRow)inputadapter_row); /* 047 */ if (shouldStop()) return; /* 048 */ } /* 049 */ /* 050 */ } /* 051 */ /* 052 */ protected void processNext() throws java.io.IOException { /* 053 */ /*** PRODUCE: Sort [id#0L ASC], true, 0 */ /* 054 */ if (sort_needToSort) { /* 055 */ sort_addToSorter(); /* 056 */ Long sort_spillSizeBefore = sort_metrics.memoryBytesSpilled(); /* 057 */ sort_sortedIter = sort_sorter.sort(); /* 058 */ sort_metricValue.add(sort_sorter.getPeakMemoryUsage()); /* 059 */ sort_metricValue1.add(sort_metrics.memoryBytesSpilled() - sort_spillSizeBefore); /* 060 */ sort_metrics.incPeakExecutionMemory(sort_sorter.getPeakMemoryUsage()); /* 061 */ sort_needToSort = false; /* 062 */ } /* 063 */ /* 064 */ while (sort_sortedIter.hasNext()) { /* 065 */ UnsafeRow sort_outputRow = (UnsafeRow)sort_sortedIter.next(); /* 066 */ /* 067 */ /*** CONSUME: WholeStageCodegen */ /* 068 */ /* 069 */ append(sort_outputRow); 
/* 070 */ /* 071 */ if (shouldStop()) return; /* 072 */ } /* 073 */ } /* 074 */ } == Subtree 3 / 3 == WholeStageCodegen : +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) : +- INPUT +- Exchange hashpartitioning(id#0L, 200), None +- WholeStageCodegen : +- TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Partial,isDistinct=false)], output=[id#0L,count#9L]) : +- Range 0, 1, 1, 100, [id#0L] Generated code: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } /* 004 */ /* 005 */ /** Codegened pipeline for: /* 006 */ * TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) /* 007 */ +- INPUT /* 008 */ */ /* 009 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { /* 010 */ private Object[] references; /* 011 */ private boolean agg_initAgg; /* 012 */ private org.apache.spark.sql.execution.aggregate.TungstenAggregate agg_plan; /* 013 */ private org.apache.spark.sql.execution.UnsafeFixedWidthAggregationMap agg_hashMap; /* 014 */ private org.apache.spark.sql.execution.UnsafeKVExternalSorter agg_sorter; /* 015 */ private org.apache.spark.unsafe.KVIterator agg_mapIter; /* 016 */ private scala.collection.Iterator inputadapter_input; /* 017 */ private UnsafeRow agg_result; /* 018 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder; /* 019 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter; /* 020 */ private UnsafeRow agg_result1; /* 021 */ private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder agg_holder1; /* 022 */ private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter agg_rowWriter1; /* 023 */ private org.apache.spark.sql.execution.metric.LongSQLMetric wholestagecodegen_numOutputRows; /* 024 */ private org.apache.spark.sql.execution.metric.LongSQLMetricValue wholestagecodegen_metricValue; /* 025 */ /* 026 */ public GeneratedIterator(Object[] references) { /* 027 */ this.references = references; /* 028 */ } /* 029 */ /* 030 */ public void init(scala.collection.Iterator inputs[]) { /* 031 */ agg_initAgg = false; /* 032 */ this.agg_plan = (org.apache.spark.sql.execution.aggregate.TungstenAggregate) references[0]; /* 033 */ agg_hashMap = agg_plan.createHashMap(); /* 034 */ /* 035 */ inputadapter_input = inputs[0]; /* 036 */ agg_result = new UnsafeRow(1); /* 037 */ this.agg_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result, 0); /* 038 */ this.agg_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder, 1); /* 039 */ agg_result1 = new UnsafeRow(2); /* 040 */ this.agg_holder1 = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(agg_result1, 0); /* 041 */ this.agg_rowWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(agg_holder1, 2); /* 042 */ this.wholestagecodegen_numOutputRows = (org.apache.spark.sql.execution.metric.LongSQLMetric) references[1]; /* 043 */ wholestagecodegen_metricValue = (org.apache.spark.sql.execution.metric.LongSQLMetricValue) wholestagecodegen_numOutputRows.localValue(); /* 044 */ } /* 045 */ /* 046 */ private void agg_doAggregateWithKeys() throws java.io.IOException { /* 047 */ /*** PRODUCE: INPUT */ /* 048 */ /* 049 */ while (inputadapter_input.hasNext()) { /* 050 */ InternalRow inputadapter_row = (InternalRow) 
inputadapter_input.next(); /* 051 */ /*** CONSUME: TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) */ /* 052 */ /* input[0, bigint] */ /* 053 */ long inputadapter_value = inputadapter_row.getLong(0); /* 054 */ /* input[1, bigint] */ /* 055 */ long inputadapter_value1 = inputadapter_row.getLong(1); /* 056 */ /* 057 */ // generate grouping key /* 058 */ agg_rowWriter.write(0, inputadapter_value); /* 059 */ /* hash(input[0, bigint], 42) */ /* 060 */ int agg_value1 = 42; /* 061 */ /* 062 */ agg_value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(inputadapter_value, agg_value1); /* 063 */ UnsafeRow agg_aggBuffer = null; /* 064 */ if (true) { /* 065 */ // try to get the buffer from hash map /* 066 */ agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1); /* 067 */ } /* 068 */ if (agg_aggBuffer == null) { /* 069 */ if (agg_sorter == null) { /* 070 */ agg_sorter = agg_hashMap.destructAndCreateExternalSorter(); /* 071 */ } else { /* 072 */ agg_sorter.merge(agg_hashMap.destructAndCreateExternalSorter()); /* 073 */ } /* 074 */ /* 075 */ // the hash map had be spilled, it should have enough memory now, /* 076 */ // try to allocate buffer again. /* 077 */ agg_aggBuffer = agg_hashMap.getAggregationBufferFromUnsafeRow(agg_result, agg_value1); /* 078 */ if (agg_aggBuffer == null) { /* 079 */ // failed to allocate the first page /* 080 */ throw new OutOfMemoryError("No enough memory for aggregation"); /* 081 */ } /* 082 */ } /* 083 */ /* 084 */ // evaluate aggregate function /* 085 */ /* (input[0, bigint] + input[2, bigint]) */ /* 086 */ /* input[0, bigint] */ /* 087 */ long agg_value4 = agg_aggBuffer.getLong(0); /* 088 */ /* 089 */ long agg_value3 = -1L; /* 090 */ agg_value3 = agg_value4 + inputadapter_value1; /* 091 */ // update aggregate buffer /* 092 */ agg_aggBuffer.setLong(0, agg_value3); /* 093 */ if (shouldStop()) return; /* 094 */ } /* 095 */ /* 096 */ agg_mapIter = agg_plan.finishAggregate(agg_hashMap, agg_sorter); /* 097 */ } /* 098 */ /* 099 */ protected void processNext() throws java.io.IOException { /* 100 */ /*** PRODUCE: TungstenAggregate(key=[id#0L], functions=[(count(1),mode=Final,isDistinct=false)], output=[id#0L,count#4L]) */ /* 101 */ /* 102 */ if (!agg_initAgg) { /* 103 */ agg_initAgg = true; /* 104 */ agg_doAggregateWithKeys(); /* 105 */ } /* 106 */ /* 107 */ // output the result /* 108 */ while (agg_mapIter.next()) { /* 109 */ wholestagecodegen_metricValue.add(1); /* 110 */ UnsafeRow agg_aggKey = (UnsafeRow) agg_mapIter.getKey(); /* 111 */ UnsafeRow agg_aggBuffer1 = (UnsafeRow) agg_mapIter.getValue(); /* 112 */ /* 113 */ /* input[0, bigint] */ /* 114 */ long agg_value6 = agg_aggKey.getLong(0); /* 115 */ /* input[0, bigint] */ /* 116 */ long agg_value7 = agg_aggBuffer1.getLong(0); /* 117 */ /* 118 */ /*** CONSUME: WholeStageCodegen */ /* 119 */ /* 120 */ agg_rowWriter1.write(0, agg_value6); /* 121 */ /* 122 */ agg_rowWriter1.write(1, agg_value7); /* 123 */ append(agg_result1); /* 124 */ /* 125 */ if (shouldStop()) return; /* 126 */ } /* 127 */ /* 128 */ agg_mapIter.close(); /* 129 */ if (agg_sorter == null) { /* 130 */ agg_hashMap.free(); /* 131 */ } /* 132 */ } /* 133 */ } ``` rxin Author: Eric Liang Closes #12025 from ericl/spark-14227. 
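For readers skimming the diff that follows, the feature boils down to: walk the executed plan, collect every `WholeStageCodegen` subtree, run code generation on each, and print the plan fragment next to its formatted source. The snippet below is a condensed sketch of that traversal using the internal names this patch introduces (`doCodeGen`, `CodeFormatter.format`); it is illustrative only, not a substitute for the `debugCodegenString` helper added in the diff.

```scala
// Condensed sketch of the traversal implemented by debugCodegenString below.
// Relies on internals named in this patch: WholeStageCodegen#doCodeGen and
// CodeFormatter.format; treat as illustrative rather than authoritative.
import scala.collection.mutable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.WholeStageCodegen
import org.apache.spark.sql.catalyst.expressions.codegen.CodeFormatter

def codegenString(df: DataFrame): String = {
  val plan = df.queryExecution.executedPlan
  val subtrees = new mutable.HashSet[WholeStageCodegen]()
  // transform is used only for its traversal side effect here.
  plan transform { case s: WholeStageCodegen => subtrees += s; s }
  subtrees.toSeq.zipWithIndex.map { case (s, i) =>
    val (_, source) = s.doCodeGen()
    s"== Subtree ${i + 1} / ${subtrees.size} ==\n$s\nGenerated code:\n" +
      CodeFormatter.format(source)
  }.mkString(s"Found ${subtrees.size} WholeStageCodegen subtrees.\n", "\n", "\n")
}
```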
--- .../spark/sql/execution/WholeStageCodegen.scala | 13 +++++- .../apache/spark/sql/execution/debug/package.scala | 46 +++++++++++++++++++--- .../spark/sql/execution/debug/DebuggingSuite.scala | 7 ++++ 3 files changed, 60 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 1b13c8fd22..da3ee46b7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -297,7 +297,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup "pipelineTime" -> SQLMetrics.createTimingMetric(sparkContext, WholeStageCodegen.PIPELINE_DURATION_METRIC)) - override def doExecute(): RDD[InternalRow] = { + /** + * Generates code for this subtree. + * + * @return the tuple of the codegen context and the actual generated source. + */ + def doCodeGen(): (CodegenContext, String) = { val ctx = new CodegenContext val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) val references = ctx.references.toArray @@ -334,6 +339,12 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup val cleanedSource = CodeFormatter.stripExtraNewLines(source) logDebug(s"${CodeFormatter.format(cleanedSource)}") CodeGenerator.compile(cleanedSource) + (ctx, cleanedSource) + } + + override def doExecute(): RDD[InternalRow] = { + val (ctx, cleanedSource) = doCodeGen() + val references = ctx.references.toArray val durationMs = longMetric("pipelineTime") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 5e573b3159..9916482a68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.internal.SQLConf @@ -41,6 +41,13 @@ import org.apache.spark.sql.internal.SQLConf */ package object debug { + /** Helper function to evade the println() linter. */ + private def debugPrint(msg: String): Unit = { + // scalastyle:off println + println(msg) + // scalastyle:on println + } + /** * Augments [[SQLContext]] with debug methods. */ @@ -62,12 +69,41 @@ package object debug { visited += new TreeNodeRef(s) DebugNode(s) } - logDebug(s"Results returned: ${debugPlan.execute().count()}") + debugPrint(s"Results returned: ${debugPlan.execute().count()}") debugPlan.foreach { case d: DebugNode => d.dumpStats() case _ => } } + + /** + * Prints to stdout all the generated code found in this plan (i.e. the output of each + * WholeStageCodegen subtree). + */ + def debugCodegen(): Unit = { + debugPrint(debugCodegenString()) + } + + /** Visible for testing. 
*/ + def debugCodegenString(): String = { + val plan = query.queryExecution.executedPlan + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]() + plan transform { + case s: WholeStageCodegen => + codegenSubtrees += s + s + case s => s + } + var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" + for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" + output += s + output += "\nGenerated code:\n" + val (_, source) = s.doCodeGen() + output += s"${CodeFormatter.format(source)}\n" + } + output + } } private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode with CodegenSupport { @@ -99,11 +135,11 @@ package object debug { val columnStats: Array[ColumnMetrics] = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { - logDebug(s"== ${child.simpleString} ==") - logDebug(s"Tuples output: ${tupleCount.value}") + debugPrint(s"== ${child.simpleString} ==") + debugPrint(s"Tuples output: ${tupleCount.value}") child.output.zip(columnStats).foreach { case (attr, metric) => val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") - logDebug(s" ${attr.name} ${attr.dataType}: $actualDataTypes") + debugPrint(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 22189477d2..979265e274 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -25,4 +25,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { test("DataFrame.debug()") { testData.debug() } + + test("debugCodegen") { + val res = sqlContext.range(10).groupBy("id").count().debugCodegenString() + assert(res.contains("Subtree 1 / 2")) + assert(res.contains("Subtree 2 / 2")) + assert(res.contains("Object[]")) + } } -- cgit v1.2.3 From a7a93a116dd9813853ba6f112beb7763931d2006 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 29 Mar 2016 15:06:29 -0700 Subject: [SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs ## What changes were proposed in this pull request? This PR brings the support for chained Python UDFs, for example ```sql select udf1(udf2(a)) select udf1(udf2(a) + 3) select udf1(udf2(a) + udf3(b)) ``` Also directly chained unary Python UDFs are put in single batch of Python UDFs, others may require multiple batches. For example, ```python >>> sqlContext.sql("select double(double(1))").explain() == Physical Plan == WholeStageCodegen : +- Project [pythonUDF#10 AS double(double(1))#9] : +- INPUT +- !BatchPythonEvaluation double(double(1)), [pythonUDF#10] +- Scan OneRowRelation[] >>> sqlContext.sql("select double(double(1) + double(2))").explain() == Physical Plan == WholeStageCodegen : +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16] : +- INPUT +- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19] +- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18] +- !BatchPythonEvaluation double(1), [pythonUDF#17] +- Scan OneRowRelation[] ``` TODO: will support multiple unrelated Python UDFs in one batch (another PR). ## How was this patch tested? Added new unit tests for chained UDFs. Author: Davies Liu Closes #12014 from davies/py_udfs. 
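The key mechanism is flattening a chain of unary Python UDFs into a single batch so only one Python round trip is needed. A minimal sketch of that recursion is shown here; it mirrors the `collectFunctions` helper added to `BatchPythonEvaluation` in the diff below, and the names and types are the internal ones used by this patch.

```scala
// Sketch of how udf1(udf2(x)) is collapsed into one Python batch.
// Mirrors collectFunctions in the BatchPythonEvaluation change below.
import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.python.PythonUDF

def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) =
  udf.children match {
    // Directly chained unary UDF: recurse, then append the outer function so
    // the Python worker composes them as g(f(x)) in a single pass.
    case Seq(u: PythonUDF) =>
      val (fs, grandChildren) = collectFunctions(u)
      (fs :+ udf.func, grandChildren)
    // Otherwise the children are plain JVM expressions evaluated before the batch.
    case children =>
      (Seq(udf.func), children)
  }
```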
--- .../org/apache/spark/api/python/PythonRDD.scala | 38 +++++++++++++--------- python/pyspark/sql/functions.py | 16 ++++++--- python/pyspark/sql/tests.py | 9 +++++ python/pyspark/worker.py | 33 ++++++++++++++++--- .../execution/python/BatchPythonEvaluation.scala | 29 ++++++++++++----- .../sql/execution/python/ExtractPythonUDFs.scala | 26 ++++++++++++++- 6 files changed, 116 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index f423b2ee56..0f579b4ef5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -59,7 +59,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = new PythonRunner(func, bufferSize, reuse_worker) + val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false) runner.compute(firstParent.iterator(split, context), split.index, context) } } @@ -81,14 +81,18 @@ private[spark] case class PythonFunction( * A helper class to run Python UDFs in Spark. */ private[spark] class PythonRunner( - func: PythonFunction, + funcs: Seq[PythonFunction], bufferSize: Int, - reuse_worker: Boolean) + reuse_worker: Boolean, + rowBased: Boolean) extends Logging { - private val envVars = func.envVars - private val pythonExec = func.pythonExec - private val accumulator = func.accumulator + // All the Python functions should have the same exec, version and envvars. + private val envVars = funcs.head.envVars + private val pythonExec = funcs.head.pythonExec + private val pythonVer = funcs.head.pythonVer + + private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF def compute( inputIterator: Iterator[_], @@ -228,10 +232,8 @@ private[spark] class PythonRunner( @volatile private var _exception: Exception = null - private val pythonVer = func.pythonVer - private val pythonIncludes = func.pythonIncludes - private val broadcastVars = func.broadcastVars - private val command = func.command + private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet + private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala) setDaemon(true) @@ -256,13 +258,13 @@ private[spark] class PythonRunner( // sparkFilesDir PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size()) - for (include <- pythonIncludes.asScala) { + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { PythonRDD.writeUTF(include, dataOut) } // Broadcast variables val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.asScala.map(_.id).toSet + val newBids = broadcastVars.map(_.id).toSet // number of different broadcasts val toRemove = oldBids.diff(newBids) val cnt = toRemove.size + newBids.diff(oldBids).size @@ -272,7 +274,7 @@ private[spark] class PythonRunner( dataOut.writeLong(- bid - 1) // bid >= 0 oldBids.remove(bid) } - for (broadcast <- broadcastVars.asScala) { + for (broadcast <- broadcastVars) { if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) @@ -282,8 +284,12 @@ private[spark] class PythonRunner( } dataOut.flush() // Serialized command: - dataOut.writeInt(command.length) - dataOut.write(command) + dataOut.writeInt(if (rowBased) 1 else 0) + 
dataOut.writeInt(funcs.length) + funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } // Data values PythonRDD.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index f5d959ef98..3211834226 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -25,7 +25,7 @@ if sys.version < "3": from itertools import imap as map from pyspark import since, SparkContext -from pyspark.rdd import _wrap_function, ignore_unicode_prefix +from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.sql.types import StringType from pyspark.sql.column import Column, _to_java_column, _to_seq @@ -1648,6 +1648,14 @@ def sort_array(col, asc=True): # ---------------------------- User Defined Function ---------------------------------- +def _wrap_function(sc, func, returnType): + ser = AutoBatchedSerializer(PickleSerializer()) + command = (func, returnType, ser) + pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) + return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, + sc.pythonVer, broadcast_vars, sc._javaAccumulator) + + class UserDefinedFunction(object): """ User defined function in Python @@ -1662,14 +1670,12 @@ class UserDefinedFunction(object): def _create_judf(self, name): from pyspark.sql import SQLContext - f, returnType = self.func, self.returnType # put them in closure `func` - func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it) - ser = AutoBatchedSerializer(PickleSerializer()) sc = SparkContext.getOrCreate() - wrapped_func = _wrap_function(sc, func, ser, ser) + wrapped_func = _wrap_function(sc, self.func, self.returnType) ctx = SQLContext.getOrCreate(sc) jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) if name is None: + f = self.func name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__ judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( name, wrapped_func, jdt) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1a5d422af9..84947560e7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -305,6 +305,15 @@ class SQLTests(ReusedPySparkTestCase): [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_chained_python_udf(self): + self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(1)").collect() + self.assertEqual(row[0], 2) + [row] = self.sqlCtx.sql("SELECT double(double(1))").collect() + self.assertEqual(row[0], 4) + [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect() + self.assertEqual(row[0], 6) + def test_udf_with_array_type(self): d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 42c2f8b759..0f05fe31aa 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -50,6 +50,18 @@ def add_path(path): sys.path.insert(1, path) +def read_command(serializer, file): + command = serializer._read_with_length(file) + if isinstance(command, Broadcast): + command = serializer.loads(command.value) + return command + + +def chain(f, g): + """chain two function together """ + return lambda x: 
g(f(x)) + + def main(infile, outfile): try: boot_time = time.time() @@ -95,10 +107,23 @@ def main(infile, outfile): _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() - command = pickleSer._read_with_length(infile) - if isinstance(command, Broadcast): - command = pickleSer.loads(command.value) - func, profiler, deserializer, serializer = command + row_based = read_int(infile) + num_commands = read_int(infile) + if row_based: + profiler = None # profiling is not supported for UDF + row_func = None + for i in range(num_commands): + f, returnType, deserializer = read_command(pickleSer, infile) + if row_func is None: + row_func = f + else: + row_func = chain(row_func, f) + serializer = deserializer + func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it) + else: + assert num_commands == 1 + func, profiler, deserializer, serializer = read_command(pickleSer, infile) + init_time = time.time() def process(): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index 79e4491026..a76009e7df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -22,10 +22,10 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.TaskContext -import org.apache.spark.api.python.PythonRunner +import org.apache.spark.api.python.{PythonFunction, PythonRunner} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.{StructField, StructType} @@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil + private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = { + udf.children match { + case Seq(u: PythonUDF) => + val (fs, children) = collectFunctions(u) + (fs ++ Seq(udf.func), children) + case children => + // There should not be any other UDFs, or the children can't be evaluated directly. + assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) + (Seq(udf.func), udf.children) + } + } + protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) @@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: // combine input with output from Python. val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() + val (pyFuncs, children) = collectFunctions(udf) + val pickle = new Pickler - val currentRow = newMutableProjection(udf.children, child.output)() - val fields = udf.children.map(_.dataType) + val currentRow = newMutableProjection(children, child.output)() + val fields = children.map(_.dataType) val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) // Input iterator to Python: input rows are grouped so we send them in batches to Python. 
@@ -75,11 +89,8 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val context = TaskContext.get() // Output iterator for results from Python. - val outputIterator = new PythonRunner( - udf.func, - bufferSize, - reuseWorker - ).compute(inputIterator, context.partitionId(), context) + val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true) + .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler val row = new GenericMutableRow(1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 6e76e9569f..c486ce18e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.python +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -25,17 +26,40 @@ import org.apache.spark.sql.catalyst.rules.Rule * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated * alone in a batch. * + * Only extracts the PythonUDFs that could be evaluated in Python (the single child is PythonUDFs + * or all the children could be evaluated in JVM). + * * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { + + private def hasPythonUDF(e: Expression): Boolean = { + e.find(_.isInstanceOf[PythonUDF]).isDefined + } + + private def canEvaluateInPython(e: PythonUDF): Boolean = { + e.children match { + // single PythonUDF child could be chained and evaluated in Python + case Seq(u: PythonUDF) => canEvaluateInPython(u) + // Python UDF can't be evaluated directly in JVM + case children => !children.exists(hasPythonUDF) + } + } + + private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = { + expr.collect { + case udf: PythonUDF if canEvaluateInPython(udf) => udf + } + } + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan case plan: LogicalPlan if plan.resolved => // Extract any PythonUDFs from the current operator. - val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) + val udfs = plan.expressions.flatMap(collectEvaluatableUDF) if (udfs.isEmpty) { // If there aren't any, we are done. plan -- cgit v1.2.3 From 366cac6fb0bb5591a0463c4696f5b9de2a294022 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 29 Mar 2016 16:46:45 -0700 Subject: [SPARK-14225][SQL] Cap the length of toCommentSafeString at 128 chars ## What changes were proposed in this pull request? Builds on https://github.com/apache/spark/pull/12022 and (a) appends "..." to truncated comment strings and (b) fixes indentation in lines after the commented strings if they happen to have a `(`, `{`, `)` or `}` ## How was this patch tested? Manually examined the generated code. Author: Sameer Agarwal Closes #12044 from sameeragarwal/comment. 
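In other words, the helper now truncates the comment source to 128 characters, marks the truncation with "...", and keeps the existing escaping of `*/` and `\u` sequences. A standalone sketch of that behaviour follows; the real change lives in catalyst's `util` package (see the diff below).

```scala
// Standalone sketch of the capped comment string; matches the behaviour
// described in the PR text (128-char cap, "..." suffix, escaping of "*/"
// and backslash-u sequences).
def toCommentSafeString(str: String): String = {
  val len = math.min(str.length, 128)
  val suffix = if (str.length > len) "..." else ""
  str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix
}

// e.g. a 500-character expression string becomes 128 characters plus "...",
// and an embedded "*/" can no longer close the surrounding comment block.
```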
--- .../expressions/codegen/CodeFormatter.scala | 37 ++++++++++++++++--- .../apache/spark/sql/catalyst/util/package.scala | 9 +++-- .../expressions/codegen/CodeFormatterSuite.scala | 41 +++++++++++++++++++++- 3 files changed, 79 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 9d99bbffbe..8e40754dc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -43,15 +43,44 @@ object CodeFormatter { private class CodeFormatter { private val code = new StringBuilder - private var indentLevel = 0 private val indentSize = 2 + + // Tracks the level of indentation in the current line. + private var indentLevel = 0 private var indentString = "" private var currentLine = 1 + // Tracks the level of indentation in multi-line comment blocks. + private var inCommentBlock = false + private var indentLevelOutsideCommentBlock = indentLevel + private def addLine(line: String): Unit = { - val indentChange = - line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) - val newIndentLevel = math.max(0, indentLevel + indentChange) + + // We currently infer the level of indentation of a given line based on a simple heuristic that + // examines the number of parenthesis and braces in that line. This isn't the most robust + // implementation but works for all code that we generate. + val indentChange = line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) + var newIndentLevel = math.max(0, indentLevel + indentChange) + + // Please note that while we try to format the comment blocks in exactly the same way as the + // rest of the code, once the block ends, we reset the next line's indentation level to what it + // was immediately before entering the comment block. + if (!inCommentBlock) { + if (line.startsWith("/*")) { + // Handle multi-line comments + inCommentBlock = true + indentLevelOutsideCommentBlock = indentLevel + } else if (line.startsWith("//")) { + // Handle single line comments + newIndentLevel = indentLevel + } + } else { + if (line.endsWith("*/")) { + inCommentBlock = false + newIndentLevel = indentLevelOutsideCommentBlock + } + } + // Lines starting with '}' should be de-indented even if they contain '{' after; // in addition, lines ending with ':' are typically labels val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index b11365b297..f879b34358 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -155,10 +155,13 @@ package object util { /** * Returns the string representation of this expression that is safe to be put in - * code comments of generated code. + * code comments of generated code. The length is capped at 128 characters. */ - def toCommentSafeString(str: String): String = - str.replace("*/", "\\*\\/").replace("\\u", "\\\\u") + def toCommentSafeString(str: String): String = { + val len = math.min(str.length, 128) + val suffix = if (str.length > len) "..." 
else "" + str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix + } /* FIX ME implicit class debugLogging(a: Any) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 9da1068e9c..d7836aa3b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -18,13 +18,20 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util._ class CodeFormatterSuite extends SparkFunSuite { def testCase(name: String)(input: String)(expected: String): Unit = { test(name) { - assert(CodeFormatter.format(input).trim === expected.trim) + if (CodeFormatter.format(input).trim !== expected.trim) { + fail( + s""" + |== FAIL: Formatted code doesn't match === + |${sideBySide(CodeFormatter.format(input).trim, expected.trim).mkString("\n")} + """.stripMargin) + } } } @@ -93,4 +100,36 @@ class CodeFormatterSuite extends SparkFunSuite { |/* 004 */ c) """.stripMargin } + + testCase("single line comments") { + """// This is a comment about class A { { { ( ( + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ // This is a comment about class A { { { ( ( + |/* 002 */ class A { + |/* 003 */ class body; + |/* 004 */ } + """.stripMargin + } + + testCase("multi-line comments") { + """ /* This is a comment about + |class A { + |class body; ...*/ + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ /* This is a comment about + |/* 002 */ class A { + |/* 003 */ class body; ...*/ + |/* 004 */ class A { + |/* 005 */ class body; + |/* 006 */ } + """.stripMargin + } } -- cgit v1.2.3 From e1f6845391078726f60e760f0ea68ccf81f9eca9 Mon Sep 17 00:00:00 2001 From: tedyu Date: Tue, 29 Mar 2016 17:16:53 -0700 Subject: [SPARK-12181] Check Cached unaligned-access capability before using Unsafe ## What changes were proposed in this pull request? For MemoryMode.OFF_HEAP, Unsafe.getInt etc. are used with no restriction. However, the Oracle implementation uses these methods only if the class variable unaligned (commented as "Cached unaligned-access capability") is true, which seems to be calculated whether the architecture is i386, x86, amd64, or x86_64. I think we should perform similar check for the use of Unsafe. Reference: https://github.com/netty/netty/blob/4.1/common/src/main/java/io/netty/util/internal/PlatformDependent0.java#L112 ## How was this patch tested? Unit test suite Author: tedyu Closes #11943 from tedyu/master. 
--- .../java/org/apache/spark/unsafe/Platform.java | 28 ++++++++++++++++++++++ .../org/apache/spark/memory/MemoryManager.scala | 3 +++ 2 files changed, 31 insertions(+) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 18761bfd22..672552cc65 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe; import java.lang.reflect.Field; +import java.lang.reflect.Method; import sun.misc.Unsafe; @@ -37,6 +38,33 @@ public final class Platform { public static final int DOUBLE_ARRAY_OFFSET; + private static final boolean unaligned; + static { + boolean _unaligned; + // use reflection to access unaligned field + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, ClassLoader.getSystemClassLoader()); + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + unalignedMethod.setAccessible(true); + _unaligned = Boolean.TRUE.equals(unalignedMethod.invoke(null)); + } catch (Throwable t) { + // We at least know x86 and x64 support unaligned access. + String arch = System.getProperty("os.arch", ""); + //noinspection DynamicRegexReplaceableByCompiledPattern + _unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64)$"); + } + unaligned = _unaligned; + } + + /** + * @return true when running JVM is having sun's Unsafe package available in it and underlying + * system having unaligned-access capability. + */ + public static boolean unaligned() { + return unaligned; + } + public static int getInt(Object object, long offset) { return _UNSAFE.getInt(object, offset); } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index 10656bc8c8..0210217e41 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.storage.BlockId import org.apache.spark.storage.memory.MemoryStore +import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.memory.MemoryAllocator @@ -190,6 +191,8 @@ private[spark] abstract class MemoryManager( if (conf.getBoolean("spark.memory.offHeap.enabled", false)) { require(conf.getSizeAsBytes("spark.memory.offHeap.size", 0) > 0, "spark.memory.offHeap.size must be > 0 when spark.memory.offHeap.enabled == true") + require(Platform.unaligned(), + "No support for unaligned Unsafe. Set spark.memory.offHeap.enabled to false.") MemoryMode.OFF_HEAP } else { MemoryMode.ON_HEAP -- cgit v1.2.3 From b66b97cd04067e1ec344fa2e28dd91e7ef937af5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 29 Mar 2016 17:39:52 -0700 Subject: [SPARK-14124][SQL] Implement Database-related DDL Commands #### What changes were proposed in this pull request? This PR is to implement the following four Database-related DDL commands: - `CREATE DATABASE|SCHEMA [IF NOT EXISTS] database_name` - `DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]` - `DESCRIBE DATABASE [EXTENDED] db_name` - `ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...)` Another PR will be submitted to handle the unsupported commands. 
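As a usage illustration (roughly what the new `DDLSuite` further below exercises), the four commands can be driven through `sqlContext.sql`. The database name and properties here are placeholders, not values taken from this patch.

```scala
// Illustrative only: placeholder database and property names.
sqlContext.sql(
  "CREATE DATABASE IF NOT EXISTS demo_db COMMENT 'demo' " +
    "WITH DBPROPERTIES ('owner'='spark')")
sqlContext.sql("ALTER DATABASE demo_db SET DBPROPERTIES ('retention'='30')")
sqlContext.sql("DESCRIBE DATABASE EXTENDED demo_db").show()
sqlContext.sql("DROP DATABASE IF EXISTS demo_db CASCADE")
```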
In the Database-related DDL commands, we will issue an error exception for `ALTER (DATABASE|SCHEMA) database_name SET OWNER [USER|ROLE] user_or_role`. cc yhuai andrewor14 rxin Could you review the changes? Is it in the right direction? Thanks! #### How was this patch tested? Added a few test cases in `command/DDLSuite.scala` for testing DDL command execution in `SQLContext`. Since `HiveContext` also shares the same implementation, the existing test cases in `\hive` also verifies the correctness of these commands. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #12009 from gatorsmile/dbDDL. --- .../sql/catalyst/catalog/SessionCatalog.scala | 6 + .../spark/sql/catalyst/catalog/interface.scala | 2 +- .../org/apache/spark/sql/execution/SparkQl.scala | 17 +-- .../spark/sql/execution/SparkSqlParser.scala | 10 +- .../apache/spark/sql/execution/command/ddl.scala | 125 ++++++++++++++--- .../sql/execution/command/DDLCommandSuite.scala | 42 ++---- .../spark/sql/execution/command/DDLSuite.scala | 151 +++++++++++++++++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 8 ++ 9 files changed, 302 insertions(+), 61 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 7165db1d5d..569b99e414 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.io.File + import scala.collection.mutable import org.apache.spark.sql.AnalysisException @@ -114,6 +116,10 @@ class SessionCatalog( currentDb = db } + def getDefaultDBPath(db: String): String = { + System.getProperty("java.io.tmpdir") + File.separator + db + ".db" + } + // ---------------------------------------------------------------------------- // Tables // ---------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 9f270236ae..303846d313 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -39,7 +39,7 @@ abstract class ExternalCatalog { protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { - throw new AnalysisException(s"Database $db does not exist") + throw new AnalysisException(s"Database '$db' does not exist") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index d4d1992d27..6fe04757ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -97,7 +97,8 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly // CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] // [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)]; - case Token("TOK_CREATEDATABASE", Token(databaseName, Nil) :: args) => + case 
Token("TOK_CREATEDATABASE", Token(dbName, Nil) :: args) => + val databaseName = cleanIdentifier(dbName) val Seq(ifNotExists, dbLocation, databaseComment, dbprops) = getClauses(Seq( "TOK_IFNOTEXISTS", "TOK_DATABASELOCATION", @@ -126,7 +127,7 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly extractProps(propList, "TOK_TABLEPROPERTY") case _ => parseFailed("Invalid CREATE DATABASE command", node) }.toMap - CreateDatabase(databaseName, ifNotExists.isDefined, location, comment, props)(node.source) + CreateDatabase(databaseName, ifNotExists.isDefined, location, comment, props) // DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]; case Token("TOK_DROPDATABASE", Token(dbName, Nil) :: otherArgs) => @@ -136,15 +137,15 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly // :- database_name // :- TOK_IFEXISTS // +- TOK_RESTRICT/TOK_CASCADE - val databaseName = unquoteString(dbName) + val databaseName = cleanIdentifier(dbName) // The default is RESTRICT val Seq(ifExists, _, cascade) = getClauses(Seq( "TOK_IFEXISTS", "TOK_RESTRICT", "TOK_CASCADE"), otherArgs) - DropDatabase(databaseName, ifExists.isDefined, restrict = cascade.isEmpty)(node.source) + DropDatabase(databaseName, ifExists.isDefined, cascade.isDefined) // ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) case Token("TOK_ALTERDATABASE_PROPERTIES", Token(dbName, Nil) :: args) => - val databaseName = unquoteString(dbName) + val databaseName = cleanIdentifier(dbName) val dbprops = getClause("TOK_DATABASEPROPERTIES", args) val props = dbprops match { case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) => @@ -161,13 +162,13 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly extractProps(propList, "TOK_TABLEPROPERTY") case _ => parseFailed("Invalid ALTER DATABASE command", node) } - AlterDatabaseProperties(databaseName, props.toMap)(node.source) + AlterDatabaseProperties(databaseName, props.toMap) // DESCRIBE DATABASE [EXTENDED] db_name case Token("TOK_DESCDATABASE", Token(dbName, Nil) :: describeArgs) => - val databaseName = unquoteString(dbName) + val databaseName = cleanIdentifier(dbName) val extended = getClauseOption("EXTENDED", describeArgs) - DescribeDatabase(databaseName, extended.isDefined)(node.source) + DescribeDatabase(databaseName, extended.isDefined) // CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name // [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri'] ]; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index a8313deeef..8333074eca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -232,8 +232,7 @@ class SparkSqlAstBuilder extends AstBuilder { ctx.EXISTS != null, Option(ctx.locationSpec).map(visitLocationSpec), Option(ctx.comment).map(string), - Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty))( - command(ctx)) + Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) } /** @@ -248,8 +247,7 @@ class SparkSqlAstBuilder extends AstBuilder { ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterDatabaseProperties( ctx.identifier.getText, - visitTablePropertyList(ctx.tablePropertyList))( - command(ctx)) + 
visitTablePropertyList(ctx.tablePropertyList)) } /** @@ -261,7 +259,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitDropDatabase(ctx: DropDatabaseContext): LogicalPlan = withOrigin(ctx) { - DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE == null)(command(ctx)) + DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE != null) } /** @@ -273,7 +271,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) { - DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null)(command(ctx)) + DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0e51abb44b..6c2a67f81c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogDatabase import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.BucketSpec @@ -45,46 +45,135 @@ abstract class NativeDDLCommand(val sql: String) extends RunnableCommand { } +/** + * A command for users to create a new database. + * + * It will issue an error message when the database with the same name already exists, + * unless 'ifNotExists' is true. + * The syntax of using this command in SQL is: + * {{{ + * CREATE DATABASE|SCHEMA [IF NOT EXISTS] database_name + * }}} + */ case class CreateDatabase( databaseName: String, ifNotExists: Boolean, path: Option[String], comment: Option[String], - props: Map[String, String])(sql: String) - extends NativeDDLCommand(sql) with Logging + props: Map[String, String]) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + catalog.createDatabase( + CatalogDatabase( + databaseName, + comment.getOrElse(""), + path.getOrElse(catalog.getDefaultDBPath(databaseName)), + props), + ifNotExists) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} + /** - * Drop Database: Removes a database from the system. + * A command for users to remove a database from the system. * * 'ifExists': * - true, if database_name does't exist, no action * - false (default), if database_name does't exist, a warning message will be issued - * 'restric': - * - true (default), the database cannot be dropped if it is not empty. The inclusive - * tables must be dropped at first. - * - false, it is in the Cascade mode. The dependent objects are automatically dropped - * before dropping database. + * 'cascade': + * - true, the dependent objects are automatically dropped before dropping database. + * - false (default), it is in the Restrict mode. The database cannot be dropped if + * it is not empty. The inclusive tables must be dropped at first. 
+ * + * The syntax of using this command in SQL is: + * {{{ + * DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]; + * }}} */ case class DropDatabase( databaseName: String, ifExists: Boolean, - restrict: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + cascade: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} -/** ALTER DATABASE: add new (key, value) pairs into DBPROPERTIES */ +/** + * A command for users to add new (key, value) pairs into DBPROPERTIES + * If the database does not exist, an error message will be issued to indicate the database + * does not exist. + * The syntax of using this command in SQL is: + * {{{ + * ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) + * }}} + */ case class AlterDatabaseProperties( databaseName: String, - props: Map[String, String])(sql: String) - extends NativeDDLCommand(sql) with Logging + props: Map[String, String]) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val db: CatalogDatabase = catalog.getDatabase(databaseName) + catalog.alterDatabase(db.copy(properties = db.properties ++ props)) + + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} /** - * DESCRIBE DATABASE: shows the name of the database, its comment (if one has been set), and its + * A command for users to show the name of the database, its comment (if one has been set), and its * root location on the filesystem. When extended is true, it also shows the database's properties + * If the database does not exist, an error message will be issued to indicate the database + * does not exist. 
+ * The syntax of using this command in SQL is + * {{{ + * DESCRIBE DATABASE [EXTENDED] db_name + * }}} */ case class DescribeDatabase( databaseName: String, - extended: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + extended: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val dbMetadata: CatalogDatabase = sqlContext.sessionState.catalog.getDatabase(databaseName) + val result = + Row("Database Name", dbMetadata.name) :: + Row("Description", dbMetadata.description) :: + Row("Location", dbMetadata.locationUri) :: Nil + + if (extended) { + val properties = + if (dbMetadata.properties.isEmpty) { + "" + } else { + dbMetadata.properties.toSeq.mkString("(", ", ", ")") + } + result :+ Row("Properties", properties) + } else { + result + } + } + + override val output: Seq[Attribute] = { + AttributeReference("database_description_item", StringType, nullable = false)() :: + AttributeReference("database_description_value", StringType, nullable = false)() :: Nil + } +} case class CreateFunction( databaseName: Option[String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 03079c6890..ccbfd41cca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -39,7 +39,7 @@ class DDLCommandSuite extends PlanTest { ifNotExists = true, Some("/home/user/db"), Some("database_comment"), - Map("a" -> "a", "b" -> "b", "c" -> "c"))(sql) + Map("a" -> "a", "b" -> "b", "c" -> "c")) comparePlans(parsed, expected) } @@ -65,39 +65,27 @@ class DDLCommandSuite extends PlanTest { val expected1 = DropDatabase( "database_name", ifExists = true, - restrict = true)(sql1) + cascade = false) val expected2 = DropDatabase( "database_name", ifExists = true, - restrict = false)(sql2) + cascade = true) val expected3 = DropDatabase( - "database_name", - ifExists = true, - restrict = true)(sql3) - val expected4 = DropDatabase( - "database_name", - ifExists = true, - restrict = false)(sql4) - val expected5 = DropDatabase( - "database_name", - ifExists = true, - restrict = true)(sql5) - val expected6 = DropDatabase( "database_name", ifExists = false, - restrict = true)(sql6) - val expected7 = DropDatabase( + cascade = false) + val expected4 = DropDatabase( "database_name", ifExists = false, - restrict = false)(sql7) + cascade = true) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - comparePlans(parsed5, expected5) - comparePlans(parsed6, expected6) - comparePlans(parsed7, expected7) + comparePlans(parsed3, expected1) + comparePlans(parsed4, expected2) + comparePlans(parsed5, expected1) + comparePlans(parsed6, expected3) + comparePlans(parsed7, expected4) } test("alter database set dbproperties") { @@ -110,10 +98,10 @@ class DDLCommandSuite extends PlanTest { val expected1 = AlterDatabaseProperties( "database_name", - Map("a" -> "a", "b" -> "b", "c" -> "c"))(sql1) + Map("a" -> "a", "b" -> "b", "c" -> "c")) val expected2 = AlterDatabaseProperties( "database_name", - Map("a" -> "a"))(sql2) + Map("a" -> "a")) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -129,10 +117,10 @@ class DDLCommandSuite extends PlanTest { val expected1 = DescribeDatabase( "db_name", - extended = true)(sql1) + extended 
= true) val expected2 = DescribeDatabase( "db_name", - extended = false)(sql2) + extended = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala new file mode 100644 index 0000000000..47c9a22acd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import java.io.File + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.catalog.CatalogDatabase +import org.apache.spark.sql.catalyst.parser.ParserUtils._ +import org.apache.spark.sql.test.SharedSQLContext + +class DDLSuite extends QueryTest with SharedSQLContext { + + /** + * Drops database `databaseName` after calling `f`. + */ + private def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + sqlContext.sql(s"DROP DATABASE IF EXISTS $name CASCADE") + } + } + } + + test("Create/Drop Database") { + val catalog = sqlContext.sessionState.catalog + + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + withDatabase(dbName) { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabase(dbNameWithoutBackTicks) + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db", + Map.empty)) + sql(s"DROP DATABASE $dbName CASCADE") + assert(!catalog.databaseExists(dbNameWithoutBackTicks)) + } + } + } + + test("Create Database - database already exists") { + val catalog = sqlContext.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabase(dbNameWithoutBackTicks) + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db", + Map.empty)) + + val message = intercept[AnalysisException] { + sql(s"CREATE DATABASE $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' already exists.")) + } + } + } + + test("Alter/Describe Database") { + val catalog = sqlContext.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + withDatabase(dbName) { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + val location = + 
System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db" + sql(s"CREATE DATABASE $dbName") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "") :: Nil) + + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) + + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) + } + } + } + + test("Drop/Alter/Describe Database - database does not exists") { + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) + + var message = intercept[AnalysisException] { + sql(s"DROP DATABASE $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + message = intercept[AnalysisException] { + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + message = intercept[AnalysisException] { + sql(s"DESCRIBE DATABASE EXTENDED $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + sql(s"DROP DATABASE IF EXISTS $dbName") + } + } + + // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 8e1ebe2937..7ad7f92bd2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -183,7 +183,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { test("Single command with --database") { runCliWithin(2.minute)( "CREATE DATABASE hive_test_db;" - -> "OK", + -> "", "USE hive_test_db;" -> "", "CREATE TABLE hive_test(key INT, val STRING);" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index ff12245e8d..1cd783e63a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf + import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.catalog.SessionCatalog @@ -59,6 +62,11 @@ class HiveSessionCatalog( // | Methods and fields for interacting with HiveMetastoreCatalog | // ---------------------------------------------------------------- + override def getDefaultDBPath(db: String): 
String = { + val defaultPath = context.hiveconf.getVar(HiveConf.ConfVars.METASTOREWAREHOUSE) + new Path(new Path(defaultPath), db + ".db").toString + } + // Catalog for handling data source tables. TODO: This really doesn't belong here since it is // essentially a cache for metastore tables. However, it relies on a lot of session-specific // things so it would be a lot of work to split its functionality between HiveSessionCatalog -- cgit v1.2.3 From 7320f9bd190afb7639cd21e956e7300fdd84c0ee Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 29 Mar 2016 21:14:48 -0700 Subject: [SPARK-14254][CORE] Add logs to help investigate the network performance ## What changes were proposed in this pull request? It would be very helpful for network performance investigation if we log the time spent on connecting and resolving host. ## How was this patch tested? Jenkins unit tests. Author: Shixiong Zhu Closes #12046 from zsxwing/connection-time. --- .../java/org/apache/spark/network/client/TransportClientFactory.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index f179bad1f4..5a36e18b09 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -123,7 +123,10 @@ public class TransportClientFactory implements Closeable { public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. + long preResolveHost = System.nanoTime(); final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; + logger.info("Spent {} ms to resolve {}", hostResolveTimeMs, address); // Create the ClientPool if we don't have it yet. ClientPool clientPool = connectionPool.get(address); @@ -235,7 +238,7 @@ public class TransportClientFactory implements Closeable { } long postBootstrap = System.nanoTime(); - logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + logger.info("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000); return client; -- cgit v1.2.3 From 816f359cf043ef719a0bc7df0506a3a830fff70d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Mar 2016 17:32:53 +0800 Subject: [SPARK-14114][SQL] implement buildReader for text data source ## What changes were proposed in this pull request? This PR implements buildReader for text data source and enable it in the new data source code path. ## How was this patch tested? Existing tests. Author: Wenchen Fan Closes #11934 from cloud-fan/text. 
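A caller-side sketch of what this change enables (illustrative only, not part of the patch; the file path and variable names below are made up): reading a plain-text file with the text data source now goes through the new per-file buildReader whenever the file-scan flag checked in FileSourceStrategy is enabled.

    // assumes an existing SQLContext named sqlContext and a text file at the (hypothetical) path
    val df = sqlContext.read.text("/tmp/input.txt")    // single string column "value", one row per line
    df.filter(df("value").contains("ERROR")).count()   // planned by FileSourceStrategy, read via buildReader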
--- .../execution/datasources/FileSourceStrategy.scala | 3 ++- .../execution/datasources/text/DefaultSource.scala | 28 +++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 20fda95154..4448796b16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -59,7 +59,8 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { if (files.fileFormat.toString == "TestFileFormat" || files.fileFormat.isInstanceOf[parquet.DefaultSource] || files.fileFormat.toString == "ORC" || - files.fileFormat.isInstanceOf[json.DefaultSource]) && + files.fileFormat.isInstanceOf[json.DefaultSource] || + files.fileFormat.isInstanceOf[text.DefaultSource]) && files.sqlContext.conf.useFileScan => // Filters on this relation fall into four categories based on where we can use them to avoid // reading unneeded data: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 5cfc9e9afa..d6ab5fc56e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.datasources.text +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} import org.apache.hadoop.mapred.{JobConf, TextInputFormat} @@ -30,7 +31,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.execution.datasources.CompressionCodecs +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -125,6 +126,31 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } } + + override def buildReader( + sqlContext: SQLContext, + partitionSchema: StructType, + dataSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + val broadcastedConf = + sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + + file => { + val unsafeRow = new UnsafeRow(1) + val bufferHolder = new BufferHolder(unsafeRow) + val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) + + new HadoopFileLinesReader(file, broadcastedConf.value.value).map { line => + // Writes to an UnsafeRow directly + bufferHolder.reset() + unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) + unsafeRow.setTotalSize(bufferHolder.totalSize()) + unsafeRow + } + } + } } class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) -- cgit v1.2.3 From 
d46c71b39da92f5cabf6d9057c953c52f7f3f965 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Mar 2016 11:03:15 -0700 Subject: [SPARK-14268][SQL] rename toRowExpressions and fromRowExpression to serializer and deserializer in ExpressionEncoder ## What changes were proposed in this pull request? In `ExpressionEncoder`, we use `constructorFor` to build `fromRowExpression` as the `deserializer` in `ObjectOperator`. It's kind of confusing, we should make the name consistent. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #12058 from cloud-fan/rename. --- .../main/scala/org/apache/spark/sql/Encoder.scala | 4 +- .../spark/sql/catalyst/JavaTypeInference.scala | 32 ++++---- .../spark/sql/catalyst/ScalaReflection.scala | 40 +++++----- .../sql/catalyst/encoders/ExpressionEncoder.scala | 87 +++++++++++----------- .../spark/sql/catalyst/encoders/RowEncoder.scala | 34 ++++----- .../spark/sql/catalyst/plans/logical/object.scala | 14 ++-- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 4 +- .../catalyst/encoders/ExpressionEncoderSuite.scala | 4 +- .../scala/org/apache/spark/sql/QueryTest.scala | 4 +- 9 files changed, 110 insertions(+), 113 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index b19538a23f..1f20e26354 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -245,10 +245,10 @@ object Encoders { ExpressionEncoder[T]( schema = new StructType().add("value", BinaryType), flat = true, - toRowExpressions = Seq( + serializer = Seq( EncodeUsingSerializer( BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - fromRowExpression = + deserializer = DecodeUsingSerializer[T]( BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), clsTag = classTag[T] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 59ee41d02f..6f9fbbbead 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -155,16 +155,16 @@ object JavaTypeInference { } /** - * Returns an expression that can be used to construct an object of java bean `T` given an input - * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * Returns an expression that can be used to deserialize an internal row to an object of java bean + * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. */ - def constructorFor(beanClass: Class[_]): Expression = { - constructorFor(TypeToken.of(beanClass), None) + def deserializerFor(beanClass: Class[_]): Expression = { + deserializerFor(TypeToken.of(beanClass), None) } - private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { /** Returns the current path with a sub-field extracted. 
*/ def addToPath(part: String): Expression = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) @@ -231,7 +231,7 @@ object JavaTypeInference { }.getOrElse { Invoke( MapObjects( - p => constructorFor(typeToken.getComponentType, Some(p)), + p => deserializerFor(typeToken.getComponentType, Some(p)), getPath, inferDataType(elementType)._1), "array", @@ -243,7 +243,7 @@ object JavaTypeInference { val array = Invoke( MapObjects( - p => constructorFor(et, Some(p)), + p => deserializerFor(et, Some(p)), getPath, inferDataType(et)._1), "array", @@ -259,7 +259,7 @@ object JavaTypeInference { val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p)), + p => deserializerFor(keyType, Some(p)), Invoke(getPath, "keyArray", ArrayType(keyDataType)), keyDataType), "array", @@ -268,7 +268,7 @@ object JavaTypeInference { val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p)), + p => deserializerFor(valueType, Some(p)), Invoke(getPath, "valueArray", ArrayType(valueDataType)), valueDataType), "array", @@ -288,7 +288,7 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (_, nullable) = inferDataType(fieldType) - val constructor = constructorFor(fieldType, Some(addToPath(fieldName))) + val constructor = deserializerFor(fieldType, Some(addToPath(fieldName))) val setter = if (nullable) { constructor } else { @@ -313,14 +313,14 @@ object JavaTypeInference { } /** - * Returns expressions for extracting all the fields from the given type. + * Returns an expression for serializing an object of the given type to an internal row. */ - def extractorsFor(beanClass: Class[_]): CreateNamedStruct = { + def serializerFor(beanClass: Class[_]): CreateNamedStruct = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] } - private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { + private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { val (dataType, nullable) = inferDataType(elementType) @@ -330,7 +330,7 @@ object JavaTypeInference { input :: Nil, dataType = ArrayType(dataType, nullable)) } else { - MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType)) + MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType)) } } @@ -403,7 +403,7 @@ object JavaTypeInference { inputObject, p.getReadMethod.getName, inferExternalType(fieldType.getRawType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil }) } else { throw new UnsupportedOperationException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f208401160..d241b8a79b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -110,8 +110,8 @@ object ScalaReflection extends ScalaReflection { } /** - * Returns an expression that can be used to construct an object of type `T` given an input - * row with a 
compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * Returns an expression that can be used to deserialize an input row to an object of type `T` + * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. * @@ -119,14 +119,14 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = { + def deserializerFor[T : TypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil - constructorFor(tpe, None, walkedTypePath) + deserializerFor(tpe, None, walkedTypePath) } - private def constructorFor( + private def deserializerFor( tpe: `Type`, path: Option[Expression], walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { @@ -161,7 +161,7 @@ object ScalaReflection extends ScalaReflection { } /** - * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff + * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff * and lost the required data type, which may lead to runtime error if the real type doesn't * match the encoder's schema. * For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type @@ -188,7 +188,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath - WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType)) + WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -272,7 +272,7 @@ object ScalaReflection extends ScalaReflection { val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath Invoke( MapObjects( - p => constructorFor(elementType, Some(p), newTypePath), + p => deserializerFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -286,7 +286,7 @@ object ScalaReflection extends ScalaReflection { val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val mapFunction: Expression => Expression = p => { - val converter = constructorFor(elementType, Some(p), newTypePath) + val converter = deserializerFor(elementType, Some(p), newTypePath) if (nullable) { converter } else { @@ -312,7 +312,7 @@ object ScalaReflection extends ScalaReflection { val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p), walkedTypePath), + p => deserializerFor(keyType, Some(p), walkedTypePath), Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), schemaFor(keyType).dataType), "array", @@ -321,7 +321,7 @@ object ScalaReflection extends ScalaReflection { val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p), walkedTypePath), + p => deserializerFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), schemaFor(valueType).dataType), "array", @@ -344,12 +344,12 @@ object ScalaReflection extends ScalaReflection { val newTypePath = s"""- field (class: 
"$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. if (cls.getName startsWith "scala.Tuple") { - constructorFor( + deserializerFor( fieldType, Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - val constructor = constructorFor( + val constructor = deserializerFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) @@ -387,7 +387,7 @@ object ScalaReflection extends ScalaReflection { } /** - * Returns expressions for extracting all the fields from the given type. + * Returns an expression for serializing an object of type T to an internal row. * * If the given type is not supported, i.e. there is no encoder can be built for this type, * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain @@ -398,18 +398,18 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { + def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil - extractorFor(inputObject, tpe, walkedTypePath) match { + serializerFor(inputObject, tpe, walkedTypePath) match { case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ - private def extractorFor( + private def serializerFor( inputObject: Expression, tpe: `Type`, walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { @@ -425,7 +425,7 @@ object ScalaReflection extends ScalaReflection { } else { val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) + MapObjects(serializerFor(_, elementType, newPath), input, externalDataType) } } @@ -491,7 +491,7 @@ object ScalaReflection extends ScalaReflection { expressions.If( IsNull(unwrapped), expressions.Literal.create(null, silentSchemaFor(optType).dataType), - extractorFor(unwrapped, optType, newPath)) + serializerFor(unwrapped, optType, newPath)) } case t if t <:< localTypeOf[Product] => @@ -500,7 +500,7 @@ object ScalaReflection extends ScalaReflection { val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil }) val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 918233ddcd..1c712fde26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -51,8 +51,8 @@ object ExpressionEncoder { val flat = !classOf[Product].isAssignableFrom(cls) val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) - val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) - val fromRowExpression = ScalaReflection.constructorFor[T] + val serializer = ScalaReflection.serializerFor[T](inputObject) + val deserializer = ScalaReflection.deserializerFor[T] val schema = ScalaReflection.schemaFor[T] match { case ScalaReflection.Schema(s: StructType, _) => s @@ -62,8 +62,8 @@ object ExpressionEncoder { new ExpressionEncoder[T]( schema, flat, - toRowExpression.flatten, - fromRowExpression, + serializer.flatten, + deserializer, ClassTag[T](cls)) } @@ -72,14 +72,14 @@ object ExpressionEncoder { val schema = JavaTypeInference.inferDataType(beanClass)._1 assert(schema.isInstanceOf[StructType]) - val toRowExpression = JavaTypeInference.extractorsFor(beanClass) - val fromRowExpression = JavaTypeInference.constructorFor(beanClass) + val serializer = JavaTypeInference.serializerFor(beanClass) + val deserializer = JavaTypeInference.deserializerFor(beanClass) new ExpressionEncoder[T]( schema.asInstanceOf[StructType], flat = false, - toRowExpression.flatten, - fromRowExpression, + serializer.flatten, + deserializer, ClassTag[T](beanClass)) } @@ -103,9 +103,9 @@ object ExpressionEncoder { val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val toRowExpressions = encoders.map { - case e if e.flat => e.toRowExpressions.head - case other => CreateStruct(other.toRowExpressions) + val serializer = encoders.map { + case e if e.flat => e.serializer.head + case other => CreateStruct(other.serializer) }.zipWithIndex.map { case (expr, index) => expr.transformUp { case BoundReference(0, t, _) => @@ -116,14 +116,14 @@ object ExpressionEncoder { } } - val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) => + val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => if (enc.flat) { - enc.fromRowExpression.transform { + enc.deserializer.transform { case b: BoundReference => b.copy(ordinal = index) } } else { val input = BoundReference(index, enc.schema, nullable = true) - enc.fromRowExpression.transformUp { + enc.deserializer.transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) @@ -132,14 +132,14 @@ object ExpressionEncoder { } } - val fromRowExpression = - NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false) + val deserializer = + NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( schema, flat = false, - toRowExpressions, - fromRowExpression, + serializer, + deserializer, ClassTag(cls)) } @@ -174,29 +174,29 @@ object ExpressionEncoder { * A generic encoder for JVM objects. * * @param schema The schema after converting `T` to a Spark SQL row. - * @param toRowExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object into an [[InternalRow]]. - * @param fromRowExpression An expression that will construct an object given an [[InternalRow]]. + * @param serializer A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]. 
+ * @param deserializer An expression that will construct an object given an [[InternalRow]]. * @param clsTag A classtag for `T`. */ case class ExpressionEncoder[T]( schema: StructType, flat: Boolean, - toRowExpressions: Seq[Expression], - fromRowExpression: Expression, + serializer: Seq[Expression], + deserializer: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(toRowExpressions.size == 1) + if (flat) require(serializer.size == 1) @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions) + private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) @transient private lazy val inputRow = new GenericMutableRow(1) @transient - private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil) + private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) /** * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns @@ -212,7 +212,7 @@ case class ExpressionEncoder[T]( * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form * of this object. */ - def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map { + def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map { case (_, ne: NamedExpression) => ne.newInstance() case (name, e) => Alias(e, name)() } @@ -228,7 +228,7 @@ case class ExpressionEncoder[T]( } catch { case e: Exception => throw new RuntimeException( - s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e) + s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e) } /** @@ -240,7 +240,7 @@ case class ExpressionEncoder[T]( constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e) + throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e) } /** @@ -249,7 +249,7 @@ case class ExpressionEncoder[T]( * has not been done already in places where we plan to do later composition of encoders. */ def assertUnresolved(): Unit = { - (fromRowExpression +: toRowExpressions).foreach(_.foreach { + (deserializer +: serializer).foreach(_.foreach { case a: AttributeReference if a.name != "loopVar" => sys.error(s"Unresolved encoder expected, but $a was found.") case _ => @@ -257,7 +257,7 @@ case class ExpressionEncoder[T]( } /** - * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce + * Validates `deserializer` to make sure it can be resolved by given schema, and produce * friendly error messages to explain why it fails to resolve if there is something wrong. */ def validate(schema: Seq[Attribute]): Unit = { @@ -271,7 +271,7 @@ case class ExpressionEncoder[T]( // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all // `BoundReference`, make sure their ordinals are all valid. var maxOrdinal = -1 - fromRowExpression.foreach { + deserializer.foreach { case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal case _ => } @@ -285,7 +285,7 @@ case class ExpressionEncoder[T]( // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after // we resolve the `fromRowExpression`. 
val resolved = SimpleAnalyzer.resolveExpression( - fromRowExpression, + deserializer, LocalRelation(schema), throws = true) @@ -312,42 +312,39 @@ case class ExpressionEncoder[T]( } /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the - * given schema. + * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema. */ def resolve( schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer( - fromRowExpression, schema) + val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema) // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check // analysis, go through optimizer, etc. - val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema)) + val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) SimpleAnalyzer.checkAnalysis(analyzedPlan) - copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head) + copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) } /** - * Returns a copy of this encoder where the expressions used to construct an object from an input - * row have been bound to the ordinals of the given schema. Note that you need to first call - * resolve before bind. + * Returns a copy of this encoder where the `deserializer` has been bound to the + * ordinals of the given schema. Note that you need to first call resolve before bind. */ def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema)) + copy(deserializer = BindReferences.bindReference(deserializer, schema)) } /** * Returns a new encoder with input columns shifted by `delta` ordinals */ def shift(delta: Int): ExpressionEncoder[T] = { - copy(fromRowExpression = fromRowExpression transform { + copy(deserializer = deserializer transform { case r: BoundReference => r.copy(ordinal = r.ordinal + delta) }) } - protected val attrs = toRowExpressions.flatMap(_.collect { + protected val attrs = serializer.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" case b: BoundReference => s"[${b.ordinal}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 30f56d8c2f..a8397aa5e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -36,23 +36,23 @@ object RowEncoder { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) // We use an If expression to wrap extractorsFor result of StructType - val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue - val constructExpression = constructorFor(schema) + val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue + val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, flat = false, - extractExpressions.asInstanceOf[CreateStruct].children, - constructExpression, + serializer.asInstanceOf[CreateStruct].children, + deserializer, ClassTag(cls)) } - private def extractorsFor( + private def serializerFor( inputObject: 
Expression, inputType: DataType): Expression = inputType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject - case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType) + case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) case udt: UserDefinedType[_] => val obj = NewInstance( @@ -95,7 +95,7 @@ object RowEncoder { classOf[GenericArrayData], inputObject :: Nil, dataType = t) - case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et)) + case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et)) } case t @ MapType(kt, vt, valueNullable) => @@ -104,14 +104,14 @@ object RowEncoder { Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = extractorsFor(keys, ArrayType(kt, false)) + val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable)) + val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) NewInstance( classOf[ArrayBasedMapData], @@ -128,7 +128,7 @@ object RowEncoder { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), - extractorsFor( + serializerFor( Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil), f.dataType)) } @@ -166,7 +166,7 @@ object RowEncoder { case _: NullType => ObjectType(classOf[java.lang.Object]) } - private def constructorFor(schema: StructType): Expression = { + private def deserializerFor(schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => val dt = f.dataType match { case p: PythonUserDefinedType => p.sqlType @@ -176,13 +176,13 @@ object RowEncoder { If( IsNull(field), Literal.create(null, externalDataTypeFor(dt)), - constructorFor(field) + deserializerFor(field) ) } CreateExternalRow(fields, schema) } - private def constructorFor(input: Expression): Expression = input.dataType match { + private def deserializerFor(input: Expression): Expression = input.dataType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType | CalendarIntervalType => input @@ -216,7 +216,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(constructorFor(_), input, et), + MapObjects(deserializerFor(_), input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( @@ -227,10 +227,10 @@ object RowEncoder { case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) - val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType)) + val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) val valueArrayType = ArrayType(vt, valueNullable) - val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) + val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( ArrayBasedMapData.getClass, @@ -243,7 +243,7 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(GetStructField(input, i))) + deserializerFor(GetStructField(input, 
i))) } If(IsNull(input), Literal.create(null, externalDataTypeFor(input.dataType)), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index da7f81c785..058fb6bff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -71,7 +71,7 @@ object MapPartitions { child: LogicalPlan): MapPartitions = { MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - encoderFor[T].fromRowExpression, + encoderFor[T].deserializer, encoderFor[U].namedExpressions, child) } @@ -98,7 +98,7 @@ object AppendColumns { child: LogicalPlan): AppendColumns = { new AppendColumns( func.asInstanceOf[Any => Any], - encoderFor[T].fromRowExpression, + encoderFor[T].deserializer, encoderFor[U].namedExpressions, child) } @@ -133,8 +133,8 @@ object MapGroups { child: LogicalPlan): MapGroups = { new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], - encoderFor[K].fromRowExpression, - encoderFor[T].fromRowExpression, + encoderFor[K].deserializer, + encoderFor[T].deserializer, encoderFor[U].namedExpressions, groupingAttributes, dataAttributes, @@ -178,9 +178,9 @@ object CoGroup { CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], - encoderFor[Key].fromRowExpression, - encoderFor[Left].fromRowExpression, - encoderFor[Right].fromRowExpression, + encoderFor[Key].deserializer, + encoderFor[Left].deserializer, + encoderFor[Right].deserializer, encoderFor[Result].namedExpressions, leftGroup, rightGroup, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index dd31050bb5..5ca5a72512 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -248,10 +248,10 @@ class ScalaReflectionSuite extends SparkFunSuite { Seq( ("mirror", () => mirror), ("dataTypeFor", () => dataTypeFor[ComplexData]), - ("constructorFor", () => constructorFor[ComplexData]), + ("constructorFor", () => deserializerFor[ComplexData]), ("extractorsFor", { val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false) - () => extractorsFor[ComplexData](inputObject) + () => serializerFor[ComplexData](inputObject) }), ("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])), ("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index f6583bfe42..18752014ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -315,7 +315,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))() val inputPlan = LocalRelation(attr) val plan = - Project(Alias(encoder.fromRowExpression, "obj")() :: Nil, + Project(Alias(encoder.deserializer, "obj")() :: Nil, 
Project(encoder.namedExpressions, inputPlan)) assertAnalysisSuccess(plan) @@ -360,7 +360,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { |${encoder.schema.treeString} | |fromRow Expressions: - |${boundEncoder.fromRowExpression.treeString} + |${boundEncoder.deserializer.treeString} """.stripMargin) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 7ff4ffcaec..854a662cc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -90,7 +90,7 @@ abstract class QueryTest extends PlanTest { s""" |Exception collecting dataset as objects |${ds.resolvedTEncoder} - |${ds.resolvedTEncoder.fromRowExpression.treeString} + |${ds.resolvedTEncoder.deserializer.treeString} |${ds.queryExecution} """.stripMargin, e) } @@ -109,7 +109,7 @@ abstract class QueryTest extends PlanTest { fail( s"""Decoded objects do not match expected objects: |$comparision - |${ds.resolvedTEncoder.fromRowExpression.treeString} + |${ds.resolvedTEncoder.deserializer.treeString} """.stripMargin) } } -- cgit v1.2.3 From bdabfd43f6e4900b48010dd00ffa48ed5fd15997 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 30 Mar 2016 13:59:10 -0700 Subject: [SPARK-13955][YARN] Also look for Spark jars in the build directory. Move the logic to find Spark jars to CommandBuilderUtils and make it available for YARN code, so that it's possible to easily launch Spark on YARN from a build directory. Tested by running SparkPi from the build directory on YARN. Author: Marcelo Vanzin Closes #11970 from vanzin/SPARK-13955. --- .../spark/launcher/AbstractCommandBuilder.java | 23 +------------------- .../apache/spark/launcher/CommandBuilderUtils.java | 25 ++++++++++++++++++++++ .../org/apache/spark/deploy/yarn/Client.scala | 11 +++++----- .../spark/launcher/YarnCommandBuilderUtils.scala | 9 ++++++++ .../org/apache/spark/deploy/yarn/ClientSuite.scala | 3 ++- 5 files changed, 42 insertions(+), 29 deletions(-) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 587fda7a3c..d02b2a4994 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -174,7 +174,7 @@ abstract class AbstractCommandBuilder { // Add Spark jars to the classpath. For the testing case, we rely on the test code to set and // propagate the test classpath appropriately. For normal invocation, look for the jars // directory under SPARK_HOME. - String jarsDir = findJarsDir(!isTesting); + String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting); if (jarsDir != null) { addToClassPath(cp, join(File.separator, jarsDir, "*")); } @@ -311,27 +311,6 @@ abstract class AbstractCommandBuilder { return props; } - private String findJarsDir(boolean failIfNotFound) { - // TODO: change to the correct directory once the assembly build is changed. 
- String sparkHome = getSparkHome(); - File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); - checkState(!failIfNotFound || libdir.isDirectory(), - "Library directory '%s' does not exist.", - libdir.getAbsolutePath()); - } else { - libdir = new File(sparkHome, String.format("assembly/target/scala-%s", getScalaVersion())); - if (!libdir.isDirectory()) { - checkState(!failIfNotFound, - "Library directory '%s' does not exist; make sure Spark is built.", - libdir.getAbsolutePath()); - libdir = null; - } - } - return libdir != null ? libdir.getAbsolutePath() : null; - } - private String getConfDir() { String confDir = getenv("SPARK_CONF_DIR"); return confDir != null ? confDir : join(File.separator, getSparkHome(), "conf"); diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 39fdf300e2..1e55aad5c9 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -349,4 +349,29 @@ class CommandBuilderUtils { return Integer.parseInt(version[1]); } } + + /** + * Find the location of the Spark jars dir, depending on whether we're looking at a build + * or a distribution directory. + */ + static String findJarsDir(String sparkHome, String scalaVersion, boolean failIfNotFound) { + // TODO: change to the correct directory once the assembly build is changed. + File libdir; + if (new File(sparkHome, "RELEASE").isFile()) { + libdir = new File(sparkHome, "lib"); + checkState(!failIfNotFound || libdir.isDirectory(), + "Library directory '%s' does not exist.", + libdir.getAbsolutePath()); + } else { + libdir = new File(sparkHome, String.format("assembly/target/scala-%s", scalaVersion)); + if (!libdir.isDirectory()) { + checkState(!failIfNotFound, + "Library directory '%s' does not exist; make sure Spark is built.", + libdir.getAbsolutePath()); + libdir = null; + } + } + return libdir != null ? libdir.getAbsolutePath() : null; + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 6bbc8c2dfa..7b29c1ae4d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -468,12 +468,11 @@ private[spark] class Client( // No configuration, so fall back to uploading local jar files. 
logWarning(s"Neither ${SPARK_JARS.key} nor ${SPARK_ARCHIVE.key} is set, falling back " + "to uploading libraries under SPARK_HOME.") - val jarsDir = new File(sparkConf.getenv("SPARK_HOME"), "lib") - if (jarsDir.isDirectory()) { - jarsDir.listFiles().foreach { f => - if (f.isFile() && f.getName().toLowerCase().endsWith(".jar")) { - distribute(f.getAbsolutePath(), targetDir = Some(LOCALIZED_LIB_DIR)) - } + val jarsDir = new File(YarnCommandBuilderUtils.findJarsDir( + sparkConf.getenv("SPARK_HOME"))) + jarsDir.listFiles().foreach { f => + if (f.isFile() && f.getName().toLowerCase().endsWith(".jar")) { + distribute(f.getAbsolutePath(), targetDir = Some(LOCALIZED_LIB_DIR)) } } } diff --git a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala index 7d246bf407..6c3556a2ee 100644 --- a/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala +++ b/yarn/src/main/scala/org/apache/spark/launcher/YarnCommandBuilderUtils.scala @@ -19,6 +19,7 @@ package org.apache.spark.launcher import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer +import scala.util.Properties /** * Exposes methods from the launcher library that are used by the YARN backend. @@ -29,6 +30,14 @@ private[spark] object YarnCommandBuilderUtils { CommandBuilderUtils.quoteForBatchScript(arg) } + def findJarsDir(sparkHome: String): String = { + val scalaVer = Properties.versionNumberString + .split("\\.") + .take(2) + .mkString(".") + CommandBuilderUtils.findJarsDir(sparkHome, scalaVer, true) + } + /** * Adds the perm gen configuration to the list of java options if needed and not yet added. * diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 24472e006b..e3613a93ed 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn -import java.io.File +import java.io.{File, FileOutputStream} import java.net.URI import java.util.Properties @@ -274,6 +274,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll val jarsDir = new File(temp, "lib") assert(jarsDir.mkdir()) val jar = TestUtils.createJarWithFiles(Map(), jarsDir) + new FileOutputStream(new File(temp, "RELEASE")).close() val sparkConf = new SparkConfWithEnv(Map("SPARK_HOME" -> temp.getAbsolutePath())) val client = createClient(sparkConf) -- cgit v1.2.3 From 529d6ce8f96ef2b4a57c2d9066c7d80466e36209 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 30 Mar 2016 14:32:29 -0700 Subject: [SPARK-14181] TrainValidationSplit should have HasSeed https://issues.apache.org/jira/browse/SPARK-14181 TrainValidationSplit should have HasSeed for the random split of RDD. I also changed the random split from the RDD function to the DataFrame function. Author: Xusen Yin Closes #11985 from yinxusen/SPARK-14181. 
--- .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 15 ++++++++++----- .../spark/ml/tuning/TrainValidationSplitSuite.scala | 4 ++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 4d1d6364d7..07330bb6b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. */ -private[ml] trait TrainValidationSplitParams extends ValidatorParams { +private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed { /** * Param for ratio between train and validation data. Must be between 0 and 1. * Default: 0.75 @@ -80,6 +81,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("1.5.0") def setTrainRatio(value: Double): this.type = set(trainRatio, value) + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + @Since("1.5.0") override def fit(dataset: DataFrame): TrainValidationSplitModel = { val schema = dataset.schema @@ -91,10 +96,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St val numModels = epm.length val metrics = new Array[Double](epm.length) - val Array(training, validation) = - dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio))) - val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() - val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() + val Array(trainingDataset, validationDataset) = + dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed)) + trainingDataset.cache() + validationDataset.cache() // multi-model training logDebug(s"Train split with multiple sets of parameters.") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 7cf7b3e087..4030956fab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -48,6 +48,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(cv.getTrainRatio === 0.5) @@ -72,6 +73,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) @@ -120,6 +122,7 @@ class TrainValidationSplitSuite .setEvaluator(evaluator) .setTrainRatio(0.5) .setEstimatorParamMaps(paramMaps) + .setSeed(42L) val tvs2 = testDefaultReadWrite(tvs, 
testParams = false) @@ -140,6 +143,7 @@ class TrainValidationSplitSuite .set(tvs.evaluator, evaluator) .set(tvs.trainRatio, 0.5) .set(tvs.estimatorParamMaps, paramMaps) + .set(tvs.seed, 42L) val tvs2 = testDefaultReadWrite(tvs, testParams = false) -- cgit v1.2.3 From 5dc948e8125fd27646a7f1e8991948a45b3f9c50 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 30 Mar 2016 14:57:38 -0700 Subject: [MINOR][ML] Fix the wrong param name of LDA topicDistributionCol ## What changes were proposed in this pull request? Fix the wrong param name of LDA ```topicDistributionCol```. ## How was this patch tested? No tests. cc jkbradley Author: Yanbo Liang Closes #12065 from yanboliang/lda-topicDistributionCol. --- mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index fe6a37fd6d..60cc345565 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -176,7 +176,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * @group param */ @Since("1.6.0") - final val topicDistributionCol = new Param[String](this, "topicDistribution", "Output column" + + final val topicDistributionCol = new Param[String](this, "topicDistributionCol", "Output column" + " with estimates of the topic mixture distribution for each document (often called \"theta\"" + " in the literature). Returns a vector of zeros for an empty document.") -- cgit v1.2.3 From f301df37cb63aeecf48077ae56351538e6eeeeb7 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 30 Mar 2016 15:47:01 -0700 Subject: [SPARK-14152][ML][PYSPARK] MultilayerPerceptronClassifier supports save/load for Python API ## What changes were proposed in this pull request? ```MultilayerPerceptronClassifier``` supports save/load for Python API. ## How was this patch tested? doctest. cc mengxr jkbradley yinxusen Author: Yanbo Liang Closes #11952 from yanboliang/spark-14152. --- python/pyspark/ml/classification.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d51b80e16c..07cafa0993 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -762,7 +762,7 @@ class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable): @inherit_doc class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasMaxIter, HasTol, HasSeed): + HasMaxIter, HasTol, HasSeed, JavaMLWritable, JavaMLReadable): """ Classifier trainer based on the Multilayer Perceptron. Each layer has sigmoid activation function, output layer has softmax. @@ -792,6 +792,18 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, |[0.0,0.0]| 0.0| +---------+----------+ ... + >>> mlp_path = temp_path + "/mlp" + >>> mlp.save(mlp_path) + >>> mlp2 = MultilayerPerceptronClassifier.load(mlp_path) + >>> mlp2.getBlockSize() + 1 + >>> model_path = temp_path + "/mlp_model" + >>> model.save(model_path) + >>> model2 = MultilayerPerceptronClassificationModel.load(model_path) + >>> model.layers == model2.layers + True + >>> model.weights == model2.weights + True .. 
versionadded:: 1.6.0 """ @@ -869,7 +881,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, return self.getOrDefault(self.blockSize) -class MultilayerPerceptronClassificationModel(JavaModel): +class MultilayerPerceptronClassificationModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by MultilayerPerceptronClassifier. -- cgit v1.2.3 From ca458618d8ee659ffa9a081083cd475a440fa8ff Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 30 Mar 2016 15:58:19 -0700 Subject: [SPARK-11507][MLLIB] add compact in Matrices fromBreeze jira: https://issues.apache.org/jira/browse/SPARK-11507 "In certain situations when adding two block matrices, I get an error regarding colPtr and the operation fails. External issue URL includes full error and code for reproducing the problem." root cause: colPtr.last does NOT always equal to values.length in breeze SCSMatrix, which fails the require in SparseMatrix. easy step to repro: ``` val m1: BM[Double] = new CSCMatrix[Double] (Array (1.0, 1, 1), 3, 3, Array (0, 1, 2, 3), Array (0, 1, 2) ) val m2: BM[Double] = new CSCMatrix[Double] (Array (1.0, 2, 2, 4), 3, 3, Array (0, 0, 2, 4), Array (1, 2, 1, 2) ) val sum = m1 + m2 Matrices.fromBreeze(sum) ``` Solution: By checking the code in [CSCMatrix](https://github.com/scalanlp/breeze/blob/28000a7b901bc3cfbbbf5c0bce1d0a5dda8281b0/math/src/main/scala/breeze/linalg/CSCMatrix.scala), CSCMatrix in breeze can have extra zeros in the end of data array. Invoking compact will make sure it aligns with the require of SparseMatrix. This should add limited overhead as the actual compact operation is only performed when necessary. Author: Yuhao Yang Closes #9520 from hhbyyh/matricesFromBreeze. --- .../main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 10 +++++++++- .../scala/org/apache/spark/mllib/linalg/MatricesSuite.scala | 12 ++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index a09bc65cf3..6e571fe35a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -940,8 +940,16 @@ object Matrices { case dm: BDM[Double] => new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => + // Spark-11507. work around breeze issue 479. 
+ val mat = if (sm.colPtrs.last != sm.data.length) { + val matCopy = sm.copy + matCopy.compact() + matCopy + } else { + sm + } // There is no isTranspose flag for sparse matrices in Breeze - new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) + new SparseMatrix(mat.rows, mat.cols, mat.colPtrs, mat.rowIndices, mat.data) case _ => throw new UnsupportedOperationException( s"Do not support conversion from type ${breeze.getClass.getName}.") diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 57907f415c..e289724cda 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.linalg import java.util.Random +import breeze.linalg.{CSCMatrix, Matrix => BM} import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar._ import scala.collection.mutable.{Map => MutableMap} @@ -499,6 +500,17 @@ class MatricesSuite extends SparkFunSuite { assert(sm1.numActives === 3) } + test("fromBreeze with sparse matrix") { + // colPtr.last does NOT always equal to values.length in breeze SCSMatrix and + // invocation of compact() may be necessary. Refer to SPARK-11507 + val bm1: BM[Double] = new CSCMatrix[Double]( + Array(1.0, 1, 1), 3, 3, Array(0, 1, 2, 3), Array(0, 1, 2)) + val bm2: BM[Double] = new CSCMatrix[Double]( + Array(1.0, 2, 2, 4), 3, 3, Array(0, 0, 2, 4), Array(1, 2, 1, 2)) + val sum = bm1 + bm2 + Matrices.fromBreeze(sum) + } + test("row/col iterator") { val dm = new DenseMatrix(3, 2, Array(0, 1, 2, 3, 4, 0)) val sm = dm.toSparse -- cgit v1.2.3 From dadf0138b3f6fd618677a2c26f40ab66b7a1139d Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Wed, 30 Mar 2016 16:02:48 -0700 Subject: [SPARK-14259][SQL] Add a FileSourceStrategy option for limiting #files in a partition ## What changes were proposed in this pull request? This pr is to add a config to control the maximum number of files as even small files have a non-trivial fixed cost. The current packing can put a lot of small files together which cases straggler tasks. ## How was this patch tested? I added tests to check if many files get split into partitions in FileSourceStrategySuite. Author: Takeshi YAMAMURO Closes #12068 from maropu/SPARK-14259. 
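A rough sketch of how the new option is meant to be used together with the existing byte limit (the path and values are illustrative, assuming a `sqlContext` in scope and a file-based data source that is planned through `FileSourceStrategy`):

```scala
// Close a scan partition once either limit is reached, so a single partition can
// no longer accumulate an unbounded number of tiny files.
sqlContext.setConf("spark.sql.files.maxPartitionBytes", (128L * 1024 * 1024).toString)
sqlContext.setConf("spark.sql.files.maxNumInPartition", "32")  // option added by this patch

val df = sqlContext.read.parquet("/path/to/table/with/many/small/files")
println(df.rdd.getNumPartitions)  // bounded by both size-based and file-count-based packing
```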
--- .../execution/datasources/FileSourceStrategy.scala | 7 +++- .../org/apache/spark/sql/internal/SQLConf.scala | 7 ++++ .../datasources/FileSourceStrategySuite.scala | 47 ++++++++++++++++++++++ 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 4448796b16..d6534083c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -136,7 +136,9 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { case _ => val maxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes - logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes") + val maxFileNumInPartition = files.sqlContext.conf.filesMaxNumInPartition + logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + + s"max #files: $maxFileNumInPartition") val splitFiles = selectedPartitions.flatMap { partition => partition.files.flatMap { file => @@ -174,7 +176,8 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { // Assign files to partitions using "First Fit Decreasing" (FFD) // TODO: consider adding a slop factor here? splitFiles.foreach { file => - if (currentSize + file.length > maxSplitBytes) { + if (currentSize + file.length > maxSplitBytes || + currentFiles.length >= maxFileNumInPartition) { closePartition() addFile(file) } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ca6ba4c643..d06e9086a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -524,6 +524,11 @@ object SQLConf { doc = "The maximum number of bytes to pack into a single partition when reading files.", isPublic = true) + val FILES_MAX_NUM_IN_PARTITION = longConf("spark.sql.files.maxNumInPartition", + defaultValue = Some(32), + doc = "The maximum number of files to pack into a single partition when reading files.", + isPublic = true) + val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", defaultValue = Some(true), doc = "When true, the planner will try to find out duplicated exchanges and re-use them.", @@ -581,6 +586,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) + def filesMaxNumInPartition: Long = getConf(FILES_MAX_NUM_IN_PARTITION) + def useCompression: Boolean = getConf(COMPRESS_CACHED) def useFileScan: Boolean = getConf(USE_FILE_SCAN) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 1fa15730bc..45620bc965 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -121,6 +121,53 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi } } + test("Unpartitioned table, many files that get split") { + val table = + createTable( + files = Seq( + "file1" -> 2, + "file2" 
-> 2, + "file3" -> 1, + "file4" -> 1, + "file5" -> 1, + "file6" -> 1)) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "3", + SQLConf.FILES_MAX_NUM_IN_PARTITION.key -> "2") { + checkScan(table.select('c1)) { partitions => + // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] + assert(partitions.size == 4, "when checking partitions") + assert(partitions(0).files.size == 1, "when checking partition 1") + assert(partitions(1).files.size == 2, "when checking partition 2") + assert(partitions(2).files.size == 2, "when checking partition 3") + assert(partitions(3).files.size == 1, "when checking partition 4") + + // First partition reads (file1) + assert(partitions(0).files(0).start == 0) + assert(partitions(0).files(0).length == 2) + + // Second partition reads (file2, file3) + assert(partitions(1).files(0).start == 0) + assert(partitions(1).files(0).length == 2) + assert(partitions(1).files(1).start == 0) + assert(partitions(1).files(1).length == 1) + + // Third partition reads (file4, file5) + assert(partitions(2).files(0).start == 0) + assert(partitions(2).files(0).length == 1) + assert(partitions(2).files(1).start == 0) + assert(partitions(2).files(1).length == 1) + + // Final partition reads (file6) + assert(partitions(3).files(0).start == 0) + assert(partitions(3).files(0).length == 1) + } + + checkPartitionSchema(StructType(Nil)) + checkDataSchema(StructType(Nil).add("c1", IntegerType)) + } + } + test("partitioned table") { val table = createTable( -- cgit v1.2.3 From 258a2434193aae62999102a8df73ca70bf0cb9f1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 30 Mar 2016 16:15:37 -0700 Subject: [SPARK-14282][SQL] CodeFormatter should handle oneline comment with /* */ properly ## What changes were proposed in this pull request? This PR improves `CodeFormatter` to fix the following malformed indentations. ```java /* 019 */ public java.lang.Object apply(java.lang.Object _i) { /* 020 */ InternalRow i = (InternalRow) _i; /* 021 */ /* createexternalrow(if (isnull(input[0, double])) null else input[0, double], if (isnull(input[1, int])) null else input[1, int], ... */ /* 022 */ boolean isNull = false; /* 023 */ final Object[] values = new Object[2]; /* 024 */ /* if (isnull(input[0, double])) null else input[0, double] */ /* 025 */ /* isnull(input[0, double]) */ ... /* 053 */ if (!false && false) { /* 054 */ /* null */ /* 055 */ final int value9 = -1; /* 056 */ isNull6 = true; /* 057 */ value6 = value9; /* 058 */ } else { ... /* 077 */ return mutableRow; /* 078 */ } /* 079 */ } /* 080 */ ``` After this PR, the code will be formatted like the following. ```java /* 019 */ public java.lang.Object apply(java.lang.Object _i) { /* 020 */ InternalRow i = (InternalRow) _i; /* 021 */ /* createexternalrow(if (isnull(input[0, double])) null else input[0, double], if (isnull(input[1, int])) null else input[1, int], ... */ /* 022 */ boolean isNull = false; /* 023 */ final Object[] values = new Object[2]; /* 024 */ /* if (isnull(input[0, double])) null else input[0, double] */ /* 025 */ /* isnull(input[0, double]) */ ... /* 053 */ if (!false && false) { /* 054 */ /* null */ /* 055 */ final int value9 = -1; /* 056 */ isNull6 = true; /* 057 */ value6 = value9; /* 058 */ } else { ... /* 077 */ return mutableRow; /* 078 */ } /* 079 */ } /* 080 */ ``` Also, this issue fixes the following too. 
(Similar with [SPARK-14185](https://issues.apache.org/jira/browse/SPARK-14185)) ```java 16/03/30 12:39:24 DEBUG WholeStageCodegen: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } ``` ```java 16/03/30 12:46:32 DEBUG WholeStageCodegen: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } ``` ## How was this patch tested? Pass the Jenkins tests (including new CodeFormatterSuite testcases.) Author: Dongjoon Hyun Closes #12072 from dongjoon-hyun/SPARK-14282. --- .../sql/catalyst/expressions/codegen/CodeFormatter.scala | 3 ++- .../catalyst/expressions/codegen/CodeFormatterSuite.scala | 14 ++++++++++++++ .../org/apache/spark/sql/execution/WholeStageCodegen.scala | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 8e40754dc3..ab4831f7ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -74,7 +74,8 @@ private class CodeFormatter { // Handle single line comments newIndentLevel = indentLevel } - } else { + } + if (inCommentBlock) { if (line.endsWith("*/")) { inCommentBlock = false newIndentLevel = indentLevelOutsideCommentBlock diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index d7836aa3b2..f57b82bb96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -115,6 +115,20 @@ class CodeFormatterSuite extends SparkFunSuite { """.stripMargin } + testCase("single line comments /* */ ") { + """/** This is a comment about class A { { { ( ( */ + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ /** This is a comment about class A { { { ( ( */ + |/* 002 */ class A { + |/* 003 */ class body; + |/* 004 */ } + """.stripMargin + } + testCase("multi-line comments") { """ /* This is a comment about |class A { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index da3ee46b7d..6a779abd40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -337,7 +337,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup // try to compile, helpful for debug val cleanedSource = CodeFormatter.stripExtraNewLines(source) - logDebug(s"${CodeFormatter.format(cleanedSource)}") + logDebug(s"\n${CodeFormatter.format(cleanedSource)}") CodeGenerator.compile(cleanedSource) (ctx, cleanedSource) } -- cgit v1.2.3 From da54abfd8730ef752eca921089bcf568773bd24a Mon Sep 17 00:00:00 2001 From: Travis Crawford Date: Wed, 30 Mar 2016 16:59:52 -0700 Subject: [SPARK-14081][SQL] - Preserve DataFrame column types when filling nulls. ## What changes were proposed in this pull request? 
This change resolves an issue where `DataFrameNaFunctions.fill` changes a `FloatType` column to a `DoubleType`. We also clarify the contract that replacement values will be cast to the column data type, which may change the replacement value when casting to a lower precision type. ## How was this patch tested? This patch has associated unit tests. Author: Travis Crawford Closes #11967 from traviscrawford/SPARK-14081-dataframena. --- .../apache/spark/sql/DataFrameNaFunctions.scala | 18 ++++---- .../spark/sql/DataFrameNaFunctionsSuite.scala | 50 +++++++++++++--------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 33588ef72f..f0e16eefc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -200,6 +200,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * The key of the map is the column name, and the value of the map is the replacement value. * The value must be of the following type: * `Integer`, `Long`, `Float`, `Double`, `String`, `Boolean`. + * Replacement values are cast to the column data type. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -217,6 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * * The key of the map is the column name, and the value of the map is the replacement value. * The value must be of the following type: `Int`, `Long`, `Float`, `Double`, `String`, `Boolean`. + * Replacement values are cast to the column data type. * * For example, the following replaces null values in column "A" with string "unknown", and * null values in column "B" with numeric value 1.0. @@ -386,10 +388,10 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { val projections = df.schema.fields.map { f => values.find { case (k, _) => columnEquals(k, f.name) }.map { case (_, v) => v match { - case v: jl.Float => fillCol[Double](f, v.toDouble) + case v: jl.Float => fillCol[Float](f, v) case v: jl.Double => fillCol[Double](f, v) - case v: jl.Long => fillCol[Double](f, v.toDouble) - case v: jl.Integer => fillCol[Double](f, v.toDouble) + case v: jl.Long => fillCol[Long](f, v) + case v: jl.Integer => fillCol[Integer](f, v) case v: jl.Boolean => fillCol[Boolean](f, v.booleanValue()) case v: String => fillCol[String](f, v) } @@ -402,13 +404,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. 
*/ private def fillCol[T](col: StructField, replacement: T): Column = { - col.dataType match { + val quotedColName = "`" + col.name + "`" + val colValue = col.dataType match { case DoubleType | FloatType => - coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)), - lit(replacement).cast(col.dataType)).as(col.name) - case _ => - coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) + nanvl(df.col(quotedColName), lit(null)) // nanvl only supports these types + case _ => df.col(quotedColName) } + coalesce(colValue, lit(replacement)).cast(col.dataType).as(col.name) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index e34875471f..18e04c24a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -141,26 +141,36 @@ class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext { } test("fill with map") { - val df = Seq[(String, String, java.lang.Long, java.lang.Double, java.lang.Boolean)]( - (null, null, null, null, null)).toDF("a", "b", "c", "d", "e") - checkAnswer( - df.na.fill(Map( - "a" -> "test", - "c" -> 1, - "d" -> 2.2, - "e" -> false - )), - Row("test", null, 1, 2.2, false)) - - // Test Java version - checkAnswer( - df.na.fill(Map( - "a" -> "test", - "c" -> 1, - "d" -> 2.2, - "e" -> false - ).asJava), - Row("test", null, 1, 2.2, false)) + val df = Seq[(String, String, java.lang.Integer, java.lang.Long, + java.lang.Float, java.lang.Double, java.lang.Boolean)]( + (null, null, null, null, null, null, null)) + .toDF("stringFieldA", "stringFieldB", "integerField", "longField", + "floatField", "doubleField", "booleanField") + + val fillMap = Map( + "stringFieldA" -> "test", + "integerField" -> 1, + "longField" -> 2L, + "floatField" -> 3.3f, + "doubleField" -> 4.4d, + "booleanField" -> false) + + val expectedRow = Row("test", null, 1, 2L, 3.3f, 4.4d, false) + + checkAnswer(df.na.fill(fillMap), expectedRow) + checkAnswer(df.na.fill(fillMap.asJava), expectedRow) // Test Java version + + // Ensure replacement values are cast to the column data type. + checkAnswer(df.na.fill(Map( + "integerField" -> 1d, + "longField" -> 2d, + "floatField" -> 3d, + "doubleField" -> 4d)), + Row(null, null, 1, 2L, 3f, 4d, null)) + + // Ensure column types do not change. Columns that have null values replaced + // will no longer be flagged as nullable, so do not compare schemas directly. + assert(df.na.fill(fillMap).schema.fields.map(_.dataType) === df.schema.fields.map(_.dataType)) } test("replace") { -- cgit v1.2.3 From 26445c2e472bad137fd350e4089dd0ff43a42039 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 30 Mar 2016 18:21:06 -0700 Subject: [SPARK-14206][SQL] buildReader() implementation for CSV ## What changes were proposed in this pull request? Major changes: 1. Implement `FileFormat.buildReader()` for the CSV data source. 1. Add an extra argument to `FileFormat.buildReader()`, `physicalSchema`, which is basically the result of `FileFormat.inferSchema` or user specified schema. This argument is necessary because the CSV data source needs to know all the columns of the underlying files to read the file. ## How was this patch tested? Existing tests should do the work. Author: Cheng Lian Closes #12002 from liancheng/spark-14206-csv-build-reader. 
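To make the `dataSchema` vs. `requiredSchema` distinction concrete, a hedged end-user sketch (the file path, column names, and the exact pruning outcome are assumptions for illustration):

```scala
import org.apache.spark.sql.types._

// User-specified schema: this plays the role of the data schema the CSV reader
// needs in order to line up the tokens on every input line.
val dataSchema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("name", StringType),
  StructField("score", DoubleType)))

val df = sqlContext.read
  .format("csv")
  .option("header", "true")
  .schema(dataSchema)
  .load("/path/to/people.csv")

// After column pruning only "id" and "score" end up in the required schema that
// the per-file reader produces rows for, even though each line is still parsed
// against the full data schema.
df.select("id", "score").show()
```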
--- .../execution/datasources/FileSourceStrategy.scala | 16 +++---- .../execution/datasources/csv/CSVRelation.scala | 41 +++++++++++++---- .../execution/datasources/csv/DefaultSource.scala | 51 +++++++++++++++++++--- .../execution/datasources/json/JSONRelation.scala | 7 +-- .../datasources/parquet/ParquetRelation.scala | 26 +++-------- .../execution/datasources/text/DefaultSource.scala | 3 +- .../org/apache/spark/sql/sources/interfaces.scala | 18 +++++--- .../datasources/FileSourceStrategySuite.scala | 5 ++- .../apache/spark/sql/hive/orc/OrcRelation.scala | 15 ++++--- 9 files changed, 119 insertions(+), 63 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index d6534083c0..554298772a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -59,8 +59,9 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { if (files.fileFormat.toString == "TestFileFormat" || files.fileFormat.isInstanceOf[parquet.DefaultSource] || files.fileFormat.toString == "ORC" || - files.fileFormat.isInstanceOf[json.DefaultSource] || - files.fileFormat.isInstanceOf[text.DefaultSource]) && + files.fileFormat.isInstanceOf[csv.DefaultSource] || + files.fileFormat.isInstanceOf[text.DefaultSource] || + files.fileFormat.isInstanceOf[json.DefaultSource]) && files.sqlContext.conf.useFileScan => // Filters on this relation fall into four categories based on where we can use them to avoid // reading unneeded data: @@ -80,14 +81,6 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { val dataColumns = l.resolve(files.dataSchema, files.sqlContext.sessionState.analyzer.resolver) - val bucketColumns = - AttributeSet( - files.bucketSpec - .map(_.bucketColumnNames) - .getOrElse(Nil) - .map(l.resolveQuoted(_, files.sqlContext.conf.resolver) - .getOrElse(sys.error("")))) - // Partition keys are not available in the statistics of the files. 
val dataFilters = filters.filter(_.references.intersect(partitionSet).isEmpty) @@ -113,8 +106,9 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { val readFile = files.fileFormat.buildReader( sqlContext = files.sqlContext, + dataSchema = files.dataSchema, partitionSchema = files.partitionSchema, - dataSchema = prunedDataSchema, + requiredSchema = prunedDataSchema, filters = pushedDownFilters, options = files.options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 5501015775..b47328a3dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.csv import scala.util.control.NonFatal -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.RecordWriter import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -49,14 +50,10 @@ object CSVRelation extends Logging { }, true) } - def parseCsv( - tokenizedRDD: RDD[Array[String]], + def csvParser( schema: StructType, requiredColumns: Array[String], - inputs: Seq[FileStatus], - sqlContext: SQLContext, - params: CSVOptions): RDD[InternalRow] = { - + params: CSVOptions): Array[String] => Option[InternalRow] = { val schemaFields = schema.fields val requiredFields = StructType(requiredColumns.map(schema(_))).fields val safeRequiredFields = if (params.dropMalformed) { @@ -74,7 +71,8 @@ object CSVRelation extends Logging { } val requiredSize = requiredFields.length val row = new GenericMutableRow(requiredSize) - tokenizedRDD.flatMap { tokens => + + (tokens: Array[String]) => { if (params.dropMalformed && schemaFields.length != tokens.length) { logWarning(s"Dropping malformed line: ${tokens.mkString(params.delimiter.toString)}") None @@ -118,6 +116,33 @@ object CSVRelation extends Logging { } } } + + def parseCsv( + tokenizedRDD: RDD[Array[String]], + schema: StructType, + requiredColumns: Array[String], + options: CSVOptions): RDD[InternalRow] = { + val parser = csvParser(schema, requiredColumns, options) + tokenizedRDD.flatMap(parser(_).toSeq) + } + + // Skips the header line of each file if the `header` option is set to true. + def dropHeaderLine( + file: PartitionedFile, lines: Iterator[String], csvOptions: CSVOptions): Unit = { + // TODO What if the first partitioned file consists of only comments and empty lines? 
+ if (csvOptions.headerFlag && file.start == 0) { + val nonEmptyLines = if (csvOptions.isCommentSet) { + val commentPrefix = csvOptions.comment.toString + lines.dropWhile { line => + line.trim.isEmpty || line.trim.startsWith(commentPrefix) + } + } else { + lines.dropWhile(_.trim.isEmpty) + } + + if (nonEmptyLines.hasNext) nonEmptyLines.drop(1) + } + } } private[sql] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWriterFactory { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 54e4c1a2c9..6b6add48cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -19,17 +19,19 @@ package org.apache.spark.sql.execution.datasources.csv import java.nio.charset.{Charset, StandardCharsets} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat -import org.apache.hadoop.mapreduce.Job +import org.apache.hadoop.mapreduce._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.execution.datasources.CompressionCodecs +import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.util.SerializableConfiguration @@ -91,6 +93,46 @@ class DefaultSource extends FileFormat with DataSourceRegister { new CSVOutputWriterFactory(csvOptions) } + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { + val csvOptions = new CSVOptions(options) + val headers = requiredSchema.fields.map(_.name) + + val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + val broadcastedConf = sqlContext.sparkContext.broadcast(new SerializableConfiguration(conf)) + + (file: PartitionedFile) => { + val lineIterator = { + val conf = broadcastedConf.value.value + new HadoopFileLinesReader(file, conf).map { line => + new String(line.getBytes, 0, line.getLength, csvOptions.charset) + } + } + + CSVRelation.dropHeaderLine(file, lineIterator, csvOptions) + + val unsafeRowIterator = { + val tokenizedIterator = new BulkCsvReader(lineIterator, csvOptions, headers) + val parser = CSVRelation.csvParser(dataSchema, requiredSchema.fieldNames, csvOptions) + tokenizedIterator.flatMap(parser(_).toSeq) + } + + // Appends partition values + val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val joinedRow = new JoinedRow() + val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput) + + unsafeRowIterator.map { dataRow => + appendPartitionColumns(joinedRow(dataRow, file.partitionValues)) + } + } + } + /** * This supports to 
eliminate unneeded columns before producing an RDD * containing all of its tuples as Row objects. This reads all the tokens of each line @@ -113,8 +155,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { val pathsString = csvFiles.map(_.getPath.toUri.toString) val header = dataSchema.fields.map(_.name) val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString) - val rows = CSVRelation.parseCsv( - tokenizedRdd, dataSchema, requiredColumns, csvFiles, sqlContext, csvOptions) + val rows = CSVRelation.parseCsv(tokenizedRdd, dataSchema, requiredColumns, csvOptions) val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get)) rows.mapPartitions { iterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 21fc1224ef..42cd25a18c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -124,8 +124,9 @@ class DefaultSource extends FileFormat with DataSourceRegister { override def buildReader( sqlContext: SQLContext, - partitionSchema: StructType, dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) @@ -136,7 +137,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { val columnNameOfCorruptRecord = parsedOptions.columnNameOfCorruptRecord .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) - val fullSchema = dataSchema.toAttributes ++ partitionSchema.toAttributes + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() file => { @@ -144,7 +145,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { val rows = JacksonParser.parseJson( lines, - dataSchema, + requiredSchema, columnNameOfCorruptRecord, parsedOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index d6b84be267..5b58fa1fc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -276,38 +276,26 @@ private[sql] class DefaultSource file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } - /** - * Returns a function that can be used to read a single file in as an Iterator of InternalRow. - * - * @param partitionSchema The schema of the partition column row that will be present in each - * PartitionedFile. These columns should be prepended to the rows that - * are produced by the iterator. - * @param dataSchema The schema of the data that should be output for each row. This may be a - * subset of the columns that are present in the file if column pruning has - * occurred. - * @param filters A set of filters than can optionally be used to reduce the number of rows output - * @param options A set of string -> string configuration options. 
- * @return - */ override def buildReader( sqlContext: SQLContext, - partitionSchema: StructType, dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { val parquetConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) parquetConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) parquetConf.set( CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, - CatalystSchemaConverter.checkFieldNames(dataSchema).json) + CatalystSchemaConverter.checkFieldNames(requiredSchema).json) parquetConf.set( CatalystWriteSupport.SPARK_ROW_SCHEMA, - CatalystSchemaConverter.checkFieldNames(dataSchema).json) + CatalystSchemaConverter.checkFieldNames(requiredSchema).json) // We want to clear this temporary metadata from saving into Parquet file. // This metadata is only useful for detecting optional columns when pushdowning filters. val dataSchemaToWrite = StructType.removeMetadata(StructType.metadataKeyForOptionalField, - dataSchema).asInstanceOf[StructType] + requiredSchema).asInstanceOf[StructType] CatalystWriteSupport.setSchema(dataSchemaToWrite, parquetConf) // Sets flags for `CatalystSchemaConverter` @@ -324,7 +312,7 @@ private[sql] class DefaultSource // Collects all converted Parquet filter predicates. Notice that not all predicates can be // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` // is used here. - .flatMap(ParquetFilters.createFilter(dataSchema, _)) + .flatMap(ParquetFilters.createFilter(requiredSchema, _)) .reduceOption(FilterApi.and) } else { None @@ -394,7 +382,7 @@ private[sql] class DefaultSource enableVectorizedParquetReader) { iter.asInstanceOf[Iterator[InternalRow]] } else { - val fullSchema = dataSchema.toAttributes ++ partitionSchema.toAttributes + val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() val appendPartitionColumns = GenerateUnsafeProjection.generate(fullSchema, fullSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index d6ab5fc56e..99459ba1d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -129,8 +129,9 @@ class DefaultSource extends FileFormat with DataSourceRegister { override def buildReader( sqlContext: SQLContext, - partitionSchema: StructType, dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { val conf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 1e02354edf..6b95a3d25b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -385,9 +385,9 @@ abstract class OutputWriter { * * @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise * this relation. 
- * @param partitionSchema The schmea of the columns (if any) that are used to partition the relation + * @param partitionSchema The schema of the columns (if any) that are used to partition the relation * @param dataSchema The schema of any remaining columns. Note that if any partition columns are - * present in the actual data files as well, they are removed. + * present in the actual data files as well, they are preserved. * @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values). * @param fileFormat A file format that can be used to read and write the data in files. * @param options Configuration used when reading / writing data. @@ -462,20 +462,24 @@ trait FileFormat { /** * Returns a function that can be used to read a single file in as an Iterator of InternalRow. * + * @param dataSchema The global data schema. It can be either specified by the user, or + * reconciled/merged from all underlying data files. If any partition columns + * are contained in the files, they are preserved in this schema. * @param partitionSchema The schema of the partition column row that will be present in each - * PartitionedFile. These columns should be prepended to the rows that + * PartitionedFile. These columns should be appended to the rows that * are produced by the iterator. - * @param dataSchema The schema of the data that should be output for each row. This may be a - * subset of the columns that are present in the file if column pruning has - * occurred. + * @param requiredSchema The schema of the data that should be output for each row. This may be a + * subset of the columns that are present in the file if column pruning has + * occurred. * @param filters A set of filters than can optionally be used to reduce the number of rows output * @param options A set of string -> string configuration options. * @return */ def buildReader( sqlContext: SQLContext, - partitionSchema: StructType, dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { // TODO: Remove this default implementation when the other formats have been ported diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 45620bc965..717a3a80b7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -376,14 +376,15 @@ class TestFileFormat extends FileFormat { override def buildReader( sqlContext: SQLContext, - partitionSchema: StructType, dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { // Record the arguments so they can be checked in the test case. 
LastArguments.partitionSchema = partitionSchema - LastArguments.dataSchema = dataSchema + LastArguments.dataSchema = requiredSchema LastArguments.filters = filters LastArguments.options = options diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 7c4a0a0c0f..43f445edcb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -126,8 +126,9 @@ private[sql] class DefaultSource override def buildReader( sqlContext: SQLContext, - partitionSchema: StructType, dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { val orcConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) @@ -145,15 +146,15 @@ private[sql] class DefaultSource (file: PartitionedFile) => { val conf = broadcastedConf.value.value - // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this - // case, `OrcFileOperator.readSchema` returns `None`, and we can simply return an empty - // iterator. + // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this + // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file + // using the given physical schema. Instead, we simply return an empty iterator. val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)) if (maybePhysicalSchema.isEmpty) { Iterator.empty } else { val physicalSchema = maybePhysicalSchema.get - OrcRelation.setRequiredColumns(conf, physicalSchema, dataSchema) + OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) val orcRecordReader = { val job = Job.getInstance(conf) @@ -171,11 +172,11 @@ private[sql] class DefaultSource // Unwraps `OrcStruct`s to `UnsafeRow`s val unsafeRowIterator = OrcRelation.unwrapOrcStructs( - file.filePath, conf, dataSchema, new RecordReaderIterator[OrcStruct](orcRecordReader) + file.filePath, conf, requiredSchema, new RecordReaderIterator[OrcStruct](orcRecordReader) ) // Appends partition values - val fullOutput = dataSchema.toAttributes ++ partitionSchema.toAttributes + val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes val joinedRow = new JoinedRow() val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput) -- cgit v1.2.3 From a9b93e07391faede77dde4c0b3c21c9b3f97f8eb Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 31 Mar 2016 09:25:09 -0700 Subject: [SPARK-14211][SQL] Remove ANTLR3 based parser ### What changes were proposed in this pull request? This PR removes the ANTLR3 based parser, and moves the new ANTLR4 based parser into the `org.apache.spark.sql.catalyst.parser package`. ### How was this patch tested? Existing unit tests. cc rxin andrewor14 yhuai Author: Herman van Hovell Closes #12071 from hvanhovell/SPARK-14211. 
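For orientation, a small sketch of the consolidated entry point after the move (object and method names are taken from the relocated `ParseDriver`/`AstBuilder` files and should be treated as approximate):

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// The ANTLR4-based parser now lives directly under ...catalyst.parser;
// the old ANTLR3 CatalystQl path and the parallel "ng" package are gone.
val plan  = CatalystSqlParser.parsePlan("SELECT a, b FROM db.tbl WHERE a > 1")
val expr  = CatalystSqlParser.parseExpression("a + rand(42)")
val table = CatalystSqlParser.parseTableIdentifier("db.tbl")
```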
--- dev/deps/spark-deps-hadoop-2.2 | 4 +- dev/deps/spark-deps-hadoop-2.3 | 4 +- dev/deps/spark-deps-hadoop-2.4 | 4 +- dev/deps/spark-deps-hadoop-2.6 | 4 +- dev/deps/spark-deps-hadoop-2.7 | 4 +- pom.xml | 6 - project/SparkBuild.scala | 54 +- project/plugins.sbt | 3 - python/pyspark/sql/utils.py | 2 +- sql/catalyst/pom.xml | 22 - .../spark/sql/catalyst/parser/ExpressionParser.g | 400 --- .../spark/sql/catalyst/parser/FromClauseParser.g | 341 --- .../spark/sql/catalyst/parser/IdentifiersParser.g | 184 -- .../spark/sql/catalyst/parser/KeywordParser.g | 244 -- .../spark/sql/catalyst/parser/SelectClauseParser.g | 235 -- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 491 ---- .../spark/sql/catalyst/parser/SparkSqlParser.g | 2596 -------------------- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 943 +++++++ .../apache/spark/sql/catalyst/parser/ng/SqlBase.g4 | 941 ------- .../apache/spark/sql/catalyst/parser/ASTNode.scala | 99 - .../catalyst/parser/AbstractSparkSQLParser.scala | 145 -- .../spark/sql/catalyst/parser/AstBuilder.scala | 1460 +++++++++++ .../spark/sql/catalyst/parser/CatalystQl.scala | 933 ------- .../spark/sql/catalyst/parser/DataTypeParser.scala | 67 + .../spark/sql/catalyst/parser/ParseDriver.scala | 245 +- .../spark/sql/catalyst/parser/ParserConf.scala | 26 - .../spark/sql/catalyst/parser/ParserUtils.scala | 209 +- .../spark/sql/catalyst/parser/ng/AstBuilder.scala | 1452 ----------- .../spark/sql/catalyst/parser/ng/ParseDriver.scala | 240 -- .../spark/sql/catalyst/parser/ng/ParserUtils.scala | 118 - .../spark/sql/catalyst/parser/ASTNodeSuite.scala | 38 - .../sql/catalyst/parser/CatalystQlSuite.scala | 223 -- .../sql/catalyst/parser/DataTypeParserSuite.scala | 1 - .../sql/catalyst/parser/ErrorParserSuite.scala | 67 + .../catalyst/parser/ExpressionParserSuite.scala | 497 ++++ .../sql/catalyst/parser/PlanParserSuite.scala | 429 ++++ .../parser/TableIdentifierParserSuite.scala | 42 + .../sql/catalyst/parser/ng/ErrorParserSuite.scala | 67 - .../catalyst/parser/ng/ExpressionParserSuite.scala | 497 ---- .../sql/catalyst/parser/ng/PlanParserSuite.scala | 429 ---- .../parser/ng/TableIdentifierParserSuite.scala | 42 - .../org/apache/spark/sql/execution/SparkQl.scala | 387 --- .../spark/sql/execution/SparkSqlParser.scala | 6 +- .../command/AlterTableCommandParser.scala | 431 ---- .../org/apache/spark/sql/internal/SQLConf.scala | 20 +- .../spark/sql/execution/command/DDLSuite.scala | 13 +- sql/hive/pom.xml | 19 - .../spark/sql/hive/HiveMetastoreCatalog.scala | 27 +- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 749 ------ .../spark/sql/hive/execution/HiveSqlParser.scala | 31 +- .../apache/spark/sql/hive/ErrorPositionSuite.scala | 4 +- .../org/apache/spark/sql/hive/HiveQlSuite.scala | 4 +- .../apache/spark/sql/hive/StatisticsSuite.scala | 5 +- 53 files changed, 3816 insertions(+), 11688 deletions(-) delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g delete mode 100644 
sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g create mode 100644 sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 delete mode 100644 sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 7c2f88bdb1..0c4e43b9c8 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -173,6 +174,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index f4d600038d..a0d62a1c30 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -164,6 +165,7 @@ 
spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 7c5e2c35bd..cc6e40329c 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -165,6 +166,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 03d9a51057..5c93db5082 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -171,6 +172,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 5765071a1c..860fd79aad 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -172,6 +173,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/pom.xml b/pom.xml index 1513a18b71..9dab0bca74 100644 --- a/pom.xml +++ b/pom.xml @@ -177,7 +177,6 @@ 3.5.2 1.3.9 0.9.2 - 3.5.2 4.5.2-1 ${java.home} @@ -1755,11 +1754,6 @@ - - org.antlr - antlr-runtime - ${antlr.version} - org.antlr antlr4-runtime diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 39a9e16f7e..5d62b688b9 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -403,59 +403,9 @@ object OldDeps { object Catalyst { lazy val settings = antlr4Settings ++ Seq( - antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng"), + antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser"), antlr4GenListener in Antlr4 := true, - antlr4GenVisitor in Antlr4 := true, - // ANTLR code-generation step. - // - // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of - // build errors in the current plugin. - // Create Parser from ANTLR grammar files. - sourceGenerators in Compile += Def.task { - val log = streams.value.log - - val grammarFileNames = Seq( - "SparkSqlLexer.g", - "SparkSqlParser.g") - val sourceDir = (sourceDirectory in Compile).value / "antlr3" - val targetDir = (sourceManaged in Compile).value / "antlr3" - - // Create default ANTLR Tool. - val antlr = new org.antlr.Tool - - // Setup input and output directories. - antlr.setInputDirectory(sourceDir.getPath) - antlr.setOutputDirectory(targetDir.getPath) - antlr.setForceRelativeOutput(true) - antlr.setMake(true) - - // Add grammar files. 
- grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => - val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath - log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) - antlr.addGrammarFile(relGFilePath) - // We will set library directory multiple times here. However, only the - // last one has effect. Because the grammar files are located under the same directory, - // We assume there is only one library directory. - antlr.setLibDirectory(gFilePath.getParent) - } - - // Generate the parser. - antlr.process() - val errorState = org.antlr.tool.ErrorManager.getErrorState - if (errorState.errors > 0) { - sys.error("ANTLR: Caught %d build errors.".format(errorState.errors)) - } else if (errorState.warnings > 0) { - sys.error("ANTLR: Caught %d build warnings.".format(errorState.warnings)) - } - - // Return all generated java files. - (targetDir ** "*.java").get.toSeq - }.taskValue, - // Include ANTLR tokens files. - resourceGenerators in Compile += Def.task { - ((sourceManaged in Compile).value ** "*.tokens").get.toSeq - }.taskValue + antlr4GenVisitor in Antlr4 := true ) } diff --git a/project/plugins.sbt b/project/plugins.sbt index d9ed7962bf..4929ba3c4d 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -22,9 +22,6 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" -libraryDependencies += "org.antlr" % "antlr" % "3.5.2" - - // TODO I am not sure we want such a dep. resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases" diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index b89ea8c6e0..7ea0e0d5c9 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -55,7 +55,7 @@ def capture_sql_exception(f): e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.catalyst.parser.ng.ParseException: '): + if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): raise ParseException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9bfe495e90..1748fa2778 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -71,10 +71,6 @@ org.codehaus.janino janino - - org.antlr - antlr-runtime - org.antlr antlr4-runtime @@ -115,24 +111,6 @@ - - org.antlr - antlr3-maven-plugin - - - - antlr - - - - - ../catalyst/src/main/antlr3 - - **/SparkSqlLexer.g - **/SparkSqlParser.g - - - org.antlr antlr4-maven-plugin diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g deleted file mode 100644 index 13a6a2d276..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ /dev/null @@ -1,400 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. 
You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. -*/ - -parser grammar ExpressionParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -// fun(par1, par2, par3) -function -@init { gParent.pushMsg("function specification", state); } -@after { gParent.popMsg(state); } - : - functionName - LPAREN - ( - (STAR) => (star=STAR) - | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? - ) - RPAREN (KW_OVER ws=window_specification)? - -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) - -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) - -> ^(TOK_FUNCTIONDI functionName (selectExpression+)? $ws?) - ; - -functionName -@init { gParent.pushMsg("function name", state); } -@after { gParent.popMsg(state); } - : // Keyword IF is also a function name - (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) - | - (functionIdentifier) => functionIdentifier - | - {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] - ; - -castExpression -@init { gParent.pushMsg("cast expression", state); } -@after { gParent.popMsg(state); } - : - KW_CAST - LPAREN - expression - KW_AS - primitiveType - RPAREN -> ^(TOK_FUNCTION primitiveType expression) - ; - -caseExpression -@init { gParent.pushMsg("case expression", state); } -@after { gParent.popMsg(state); } - : - KW_CASE expression - (KW_WHEN expression KW_THEN expression)+ - (KW_ELSE expression)? - KW_END -> ^(TOK_FUNCTION KW_CASE expression*) - ; - -whenExpression -@init { gParent.pushMsg("case expression", state); } -@after { gParent.popMsg(state); } - : - KW_CASE - ( KW_WHEN expression KW_THEN expression)+ - (KW_ELSE expression)? - KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) - ; - -constant -@init { gParent.pushMsg("constant", state); } -@after { gParent.popMsg(state); } - : - Number - | dateLiteral - | timestampLiteral - | intervalLiteral - | StringLiteral - | stringLiteralSequence - | BigintLiteral - | SmallintLiteral - | TinyintLiteral - | DoubleLiteral - | booleanValue - ; - -stringLiteralSequence - : - StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) - ; - -dateLiteral - : - KW_DATE StringLiteral -> - { - // Create DateLiteral token, but with the text of the string value - // This makes the dateLiteral more consistent with the other type literals. 
- adaptor.create(TOK_DATELITERAL, $StringLiteral.text) - } - | - KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) - ; - -timestampLiteral - : - KW_TIMESTAMP StringLiteral -> - { - adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) - } - | - KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) - ; - -intervalLiteral - : - (KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH) => KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH - -> ^(TOK_INTERVAL_YEAR_MONTH_LITERAL intervalConstant) - | (KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND) => KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND - -> ^(TOK_INTERVAL_DAY_TIME_LITERAL intervalConstant) - | KW_INTERVAL - ((intervalConstant KW_YEAR)=> year=intervalConstant KW_YEAR)? - ((intervalConstant KW_MONTH)=> month=intervalConstant KW_MONTH)? - ((intervalConstant KW_WEEK)=> week=intervalConstant KW_WEEK)? - ((intervalConstant KW_DAY)=> day=intervalConstant KW_DAY)? - ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)? - ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)? - ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)? - ((intervalConstant KW_MILLISECOND)=> millisecond=intervalConstant KW_MILLISECOND)? - ((intervalConstant KW_MICROSECOND)=> microsecond=intervalConstant KW_MICROSECOND)? - -> ^(TOK_INTERVAL - ^(TOK_INTERVAL_YEAR_LITERAL $year?) - ^(TOK_INTERVAL_MONTH_LITERAL $month?) - ^(TOK_INTERVAL_WEEK_LITERAL $week?) - ^(TOK_INTERVAL_DAY_LITERAL $day?) - ^(TOK_INTERVAL_HOUR_LITERAL $hour?) - ^(TOK_INTERVAL_MINUTE_LITERAL $minute?) - ^(TOK_INTERVAL_SECOND_LITERAL $second?) - ^(TOK_INTERVAL_MILLISECOND_LITERAL $millisecond?) - ^(TOK_INTERVAL_MICROSECOND_LITERAL $microsecond?)) - ; - -intervalConstant - : - sign=(MINUS|PLUS)? value=Number -> { - adaptor.create(Number, ($sign != null ? $sign.getText() : "") + $value.getText()) - } - | StringLiteral - ; - -expression -@init { gParent.pushMsg("expression specification", state); } -@after { gParent.popMsg(state); } - : - precedenceOrExpression - ; - -atomExpression - : - (KW_NULL) => KW_NULL -> TOK_NULL - | (constant) => constant - | castExpression - | caseExpression - | whenExpression - | (functionName LPAREN) => function - | tableOrColumn - | (LPAREN KW_SELECT) => subQueryExpression - -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP) subQueryExpression) - | LPAREN! expression RPAREN! - ; - - -precedenceFieldExpression - : - atomExpression ((LSQUARE^ expression RSQUARE!) | (DOT^ identifier))* - ; - -precedenceUnaryOperator - : - PLUS | MINUS | TILDE - ; - -nullCondition - : - KW_NULL -> ^(TOK_ISNULL) - | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) - ; - -precedenceUnaryPrefixExpression - : - (precedenceUnaryOperator+)=> precedenceUnaryOperator^ precedenceUnaryPrefixExpression - | precedenceFieldExpression - ; - -precedenceUnarySuffixExpression - : - ( - (LPAREN precedenceUnaryPrefixExpression RPAREN) => LPAREN precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? RPAREN - | - precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? - ) - -> {$a != null}? 
^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) - -> precedenceUnaryPrefixExpression - ; - - -precedenceBitwiseXorOperator - : - BITWISEXOR - ; - -precedenceBitwiseXorExpression - : - precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* - ; - - -precedenceStarOperator - : - STAR | DIVIDE | MOD | DIV - ; - -precedenceStarExpression - : - precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* - ; - - -precedencePlusOperator - : - PLUS | MINUS - ; - -precedencePlusExpression - : - precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* - ; - - -precedenceAmpersandOperator - : - AMPERSAND - ; - -precedenceAmpersandExpression - : - precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* - ; - - -precedenceBitwiseOrOperator - : - BITWISEOR - ; - -precedenceBitwiseOrExpression - : - precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* - ; - - -// Equal operators supporting NOT prefix -precedenceEqualNegatableOperator - : - KW_LIKE | KW_RLIKE | KW_REGEXP - ; - -precedenceEqualOperator - : - precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -subQueryExpression - : - LPAREN! selectStatement[true] RPAREN! - ; - -precedenceEqualExpression - : - (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple - | - precedenceEqualExpressionSingle - ; - -precedenceEqualExpressionSingle - : - (left=precedenceBitwiseOrExpression -> $left) - ( - (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) - -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) - | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) - -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) - | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) - -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) - | (KW_NOT KW_IN expressions) - -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) - | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) - -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) - | (KW_IN expressions) - -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) - | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) - -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) - | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) - -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) - )* - | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) - ; - -expressions - : - LPAREN expression (COMMA expression)* RPAREN -> expression+ - ; - -//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) -precedenceEqualExpressionMutiple - : - (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) - ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) - -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) - | 
(KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) - -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) - ; - -expressionsToStruct - : - LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) - ; - -precedenceNotOperator - : - KW_NOT - ; - -precedenceNotExpression - : - (precedenceNotOperator^)* precedenceEqualExpression - ; - - -precedenceAndOperator - : - KW_AND - ; - -precedenceAndExpression - : - precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* - ; - - -precedenceOrOperator - : - KW_OR - ; - -precedenceOrExpression - : - precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g deleted file mode 100644 index 1bf461c912..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ /dev/null @@ -1,341 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/FromClauseParser.g grammar. 
-*/ -parser grammar FromClauseParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------------------------------------------------------------------- - -tableAllColumns - : STAR - -> ^(TOK_ALLCOLREF) - | tableName DOT STAR - -> ^(TOK_ALLCOLREF tableName) - ; - -// (table|column) -tableOrColumn -@init { gParent.pushMsg("table or column identifier", state); } -@after { gParent.popMsg(state); } - : - identifier -> ^(TOK_TABLE_OR_COL identifier) - ; - -expressionList -@init { gParent.pushMsg("expression list", state); } -@after { gParent.popMsg(state); } - : - expression (COMMA expression)* -> ^(TOK_EXPLIST expression+) - ; - -aliasList -@init { gParent.pushMsg("alias list", state); } -@after { gParent.popMsg(state); } - : - identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+) - ; - -//----------------------- Rules for parsing fromClause ------------------------------ -// from [col1, col2, col3] table1, [col4, col5] table2 -fromClause -@init { gParent.pushMsg("from clause", state); } -@after { gParent.popMsg(state); } - : - KW_FROM joinSource -> ^(TOK_FROM joinSource) - ; - -joinSource -@init { gParent.pushMsg("join source", state); } -@after { gParent.popMsg(state); } - : fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )* - | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+ - ; - -joinCond -@init { gParent.pushMsg("join expression list", state); } -@after { gParent.popMsg(state); } - : KW_ON! expression - | KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList) - ; - -uniqueJoinSource -@init { gParent.pushMsg("unique join source", state); } -@after { gParent.popMsg(state); } - : KW_PRESERVE? fromSource uniqueJoinExpr - ; - -uniqueJoinExpr -@init { gParent.pushMsg("unique join expression list", state); } -@after { gParent.popMsg(state); } - : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN - -> ^(TOK_EXPLIST $e1*) - ; - -uniqueJoinToken -@init { gParent.pushMsg("unique join", state); } -@after { gParent.popMsg(state); } - : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN; - -joinToken -@init { gParent.pushMsg("join type specifier", state); } -@after { gParent.popMsg(state); } - : - KW_JOIN -> TOK_JOIN - | KW_INNER KW_JOIN -> TOK_JOIN - | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN - | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN - | COMMA -> TOK_JOIN - | KW_CROSS KW_JOIN -> TOK_CROSSJOIN - | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN - | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN - | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN - | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN - | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN - | KW_NATURAL KW_FULL (KW_OUTER)? 
KW_JOIN -> TOK_NATURALFULLOUTERJOIN - | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN - | KW_ANTI KW_JOIN -> TOK_ANTIJOIN - ; - -lateralView -@init {gParent.pushMsg("lateral view", state); } -@after {gParent.popMsg(state); } - : - (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? - -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) - | - KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? - -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) - ; - -tableAlias -@init {gParent.pushMsg("table alias", state); } -@after {gParent.popMsg(state); } - : - identifier -> ^(TOK_TABALIAS identifier) - ; - -fromSource -@init { gParent.pushMsg("from source", state); } -@after { gParent.popMsg(state); } - : - (LPAREN KW_VALUES) => fromSource0 - | fromSource0 - | (LPAREN joinSource) => LPAREN joinSource RPAREN -> joinSource - ; - - -fromSource0 -@init { gParent.pushMsg("from source 0", state); } -@after { gParent.popMsg(state); } - : - ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)* - ; - -tableBucketSample -@init { gParent.pushMsg("table bucket sample specification", state); } -@after { gParent.popMsg(state); } - : - KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*) - ; - -splitSample -@init { gParent.pushMsg("table split sample specification", state); } -@after { gParent.popMsg(state); } - : - KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN - -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator) - -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator) - | - KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN - -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator) - ; - -tableSample -@init { gParent.pushMsg("table sample specification", state); } -@after { gParent.popMsg(state); } - : - tableBucketSample | - splitSample - ; - -tableSource -@init { gParent.pushMsg("table source", state); } -@after { gParent.popMsg(state); } - : tabname=tableName - ((tableProperties) => props=tableProperties)? - ((tableSample) => ts=tableSample)? - ((KW_AS) => (KW_AS alias=Identifier) - | - (Identifier) => (alias=Identifier))? - -> ^(TOK_TABREF $tabname $props? $ts? $alias?) - ; - -tableName -@init { gParent.pushMsg("table name", state); } -@after { gParent.popMsg(state); } - : - id1=identifier (DOT id2=identifier)? - -> ^(TOK_TABNAME $id1 $id2?) - ; - -viewName -@init { gParent.pushMsg("view name", state); } -@after { gParent.popMsg(state); } - : - (db=identifier DOT)? view=identifier - -> ^(TOK_TABNAME $db? $view) - ; - -subQuerySource -@init { gParent.pushMsg("subquery source", state); } -@after { gParent.popMsg(state); } - : - LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier) - ; - -//---------------------- Rules for parsing PTF clauses ----------------------------- -partitioningSpec -@init { gParent.pushMsg("partitioningSpec clause", state); } -@after { gParent.popMsg(state); } - : - partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) | - orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) | - distributeByClause sortByClause? 
-> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) | - sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) | - clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause) - ; - -partitionTableFunctionSource -@init { gParent.pushMsg("partitionTableFunctionSource clause", state); } -@after { gParent.popMsg(state); } - : - subQuerySource | - tableSource | - partitionedTableFunction - ; - -partitionedTableFunction -@init { gParent.pushMsg("ptf clause", state); } -@after { gParent.popMsg(state); } - : - name=Identifier LPAREN KW_ON - ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?)) - ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)? - ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)? - -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*) - ; - -//----------------------- Rules for parsing whereClause ----------------------------- -// where a=b and ... -whereClause -@init { gParent.pushMsg("where clause", state); } -@after { gParent.popMsg(state); } - : - KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition) - ; - -searchCondition -@init { gParent.pushMsg("search condition", state); } -@after { gParent.popMsg(state); } - : - expression - ; - -//----------------------------------------------------------------------------------- - -//-------- Row Constructor ---------------------------------------------------------- -//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and -// INSERT INTO (col1,col2,...) VALUES(...),(...),... -// INSERT INTO
(col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c) -valueRowConstructor -@init { gParent.pushMsg("value row constructor", state); } -@after { gParent.popMsg(state); } - : - LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+) - ; - -valuesTableConstructor -@init { gParent.pushMsg("values table constructor", state); } -@after { gParent.popMsg(state); } - : - valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+) - ; - -/* -VALUES(1),(2) means 2 rows, 1 column each. -VALUES(1,2),(3,4) means 2 rows, 2 columns each. -VALUES(1,2,3) means 1 row, 3 columns -*/ -valuesClause -@init { gParent.pushMsg("values clause", state); } -@after { gParent.popMsg(state); } - : - KW_VALUES valuesTableConstructor -> valuesTableConstructor - ; - -/* -This represents a clause like this: -(VALUES(1,2),(2,3)) as VirtTable(col1,col2) -*/ -virtualTableSource -@init { gParent.pushMsg("virtual table source", state); } -@after { gParent.popMsg(state); } - : - LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause) - ; -/* -e.g. as VirtTable(col1,col2) -Note that we only want literals as column names -*/ -tableNameColList -@init { gParent.pushMsg("from source", state); } -@after { gParent.popMsg(state); } - : - KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+)) - ; - -//----------------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g deleted file mode 100644 index 916eb6a7ac..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g +++ /dev/null @@ -1,184 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. 
-*/ -parser grammar IdentifiersParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------------------------------------------------------------------- - -// group by a,b -groupByClause -@init { gParent.pushMsg("group by clause", state); } -@after { gParent.popMsg(state); } - : - KW_GROUP KW_BY - expression - ( COMMA expression)* - ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ? - (sets=KW_GROUPING KW_SETS - LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ? - -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+) - -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+) - -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+) - -> ^(TOK_GROUPBY expression+) - ; - -groupingSetExpression -@init {gParent.pushMsg("grouping set expression", state); } -@after {gParent.popMsg(state); } - : - (LPAREN) => groupingSetExpressionMultiple - | - groupingExpressionSingle - ; - -groupingSetExpressionMultiple -@init {gParent.pushMsg("grouping set part expression", state); } -@after {gParent.popMsg(state); } - : - LPAREN - expression? (COMMA expression)* - RPAREN - -> ^(TOK_GROUPING_SETS_EXPRESSION expression*) - ; - -groupingExpressionSingle -@init { gParent.pushMsg("groupingExpression expression", state); } -@after { gParent.popMsg(state); } - : - expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression) - ; - -havingClause -@init { gParent.pushMsg("having clause", state); } -@after { gParent.popMsg(state); } - : - KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition) - ; - -havingCondition -@init { gParent.pushMsg("having condition", state); } -@after { gParent.popMsg(state); } - : - expression - ; - -expressionsInParenthese - : - LPAREN expression (COMMA expression)* RPAREN -> expression+ - ; - -expressionsNotInParenthese - : - expression (COMMA expression)* -> expression+ - ; - -columnRefOrderInParenthese - : - LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+ - ; - -columnRefOrderNotInParenthese - : - columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+ - ; - -// order by a,b -orderByClause -@init { gParent.pushMsg("order by clause", state); } -@after { gParent.popMsg(state); } - : - KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+) - ; - -clusterByClause -@init { gParent.pushMsg("cluster by clause", state); } -@after { gParent.popMsg(state); } - : - KW_CLUSTER KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese) - ) - ; - -partitionByClause -@init { gParent.pushMsg("partition by clause", state); } -@after { gParent.popMsg(state); } - : - KW_PARTITION KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) - ) - ; - -distributeByClause -@init { gParent.pushMsg("distribute by clause", 
state); } -@after { gParent.popMsg(state); } - : - KW_DISTRIBUTE KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) - ) - ; - -sortByClause -@init { gParent.pushMsg("sort by clause", state); } -@after { gParent.popMsg(state); } - : - KW_SORT KW_BY - ( - (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese) - | - columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese) - ) - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g deleted file mode 100644 index 12cd5f54a0..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g +++ /dev/null @@ -1,244 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. -*/ - -parser grammar KeywordParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -booleanValue - : - KW_TRUE^ | KW_FALSE^ - ; - -booleanValueTok - : - KW_TRUE -> TOK_TRUE - | KW_FALSE -> TOK_FALSE - ; - -tableOrPartition - : - tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) - ; - -partitionSpec - : - KW_PARTITION - LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) - ; - -partitionVal - : - identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) 
- ; - -dropPartitionSpec - : - KW_PARTITION - LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) - ; - -dropPartitionVal - : - identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) - ; - -dropPartitionOperator - : - EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -sysFuncNames - : - KW_AND - | KW_OR - | KW_NOT - | KW_LIKE - | KW_IF - | KW_CASE - | KW_WHEN - | KW_TINYINT - | KW_SMALLINT - | KW_INT - | KW_BIGINT - | KW_FLOAT - | KW_DOUBLE - | KW_BOOLEAN - | KW_STRING - | KW_BINARY - | KW_ARRAY - | KW_MAP - | KW_STRUCT - | KW_UNIONTYPE - | EQUAL - | EQUAL_NS - | NOTEQUAL - | LESSTHANOREQUALTO - | LESSTHAN - | GREATERTHANOREQUALTO - | GREATERTHAN - | DIVIDE - | PLUS - | MINUS - | STAR - | MOD - | DIV - | AMPERSAND - | TILDE - | BITWISEOR - | BITWISEXOR - | KW_RLIKE - | KW_REGEXP - | KW_IN - | KW_BETWEEN - ; - -descFuncNames - : - (sysFuncNames) => sysFuncNames - | StringLiteral - | functionIdentifier - ; - -//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. -looseIdentifier - : - Identifier - | looseNonReserved -> Identifier[$looseNonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -identifier - : - Identifier - | nonReserved -> Identifier[$nonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -functionIdentifier -@init { gParent.pushMsg("function identifier", state); } -@after { gParent.popMsg(state); } - : - identifier (DOT identifier)? -> identifier+ - ; - -principalIdentifier -@init { gParent.pushMsg("identifier for principal spec", state); } -@after { gParent.popMsg(state); } - : identifier - | QuotedIdentifier - ; - -looseNonReserved - : nonReserved | KW_FROM | KW_TO - ; - -//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved -//Non reserved keywords are basically the keywords that can be used as identifiers. -//All the KW_* are automatically not only keywords, but also reserved keywords. -//That means, they can NOT be used as identifiers. -//If you would like to use them as identifiers, put them in the nonReserved list below. 
-//If you are not sure, please refer to the SQL2011 column in -//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html -nonReserved - : - KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS - | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS - | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY - | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY - | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE - | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT - | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE - | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR - | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG - | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE - | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY - | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER - | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE - | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED - | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED - | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED - | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET - | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR - | KW_WORK - | KW_TRANSACTION - | KW_WRITE - | KW_ISOLATION - | KW_LEVEL - | KW_SNAPSHOT - | KW_AUTOCOMMIT - | KW_ANTI - | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND - | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS -; - -//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. -sql11ReservedKeywordsUsedAsCastFunctionName - : - KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP - ; - -//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. -//We are planning to remove the following whole list after several releases. -//Thus, please do not change the following list unless you know what to do. 
-sql11ReservedKeywordsUsedAsIdentifier - : - KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN - | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE - | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT - | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL - | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION - | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT - | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE - | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH -//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. - | KW_REGEXP | KW_RLIKE - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g deleted file mode 100644 index f18b6ec496..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g +++ /dev/null @@ -1,235 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/SelectClauseParser.g grammar. -*/ -parser grammar SelectClauseParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------- Rules for parsing selectClause ----------------------------- -// select a,b,c ... -selectClause -@init { gParent.pushMsg("select clause", state); } -@after { gParent.popMsg(state); } - : - KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList) - | (transform=KW_TRANSFORM selectTrfmClause)) - -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList) - -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList) - -> ^(TOK_SELECT hintClause? 
^(TOK_SELEXPR selectTrfmClause) ) - | - trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause)) - ; - -selectList -@init { gParent.pushMsg("select list", state); } -@after { gParent.popMsg(state); } - : - selectItem ( COMMA selectItem )* -> selectItem+ - ; - -selectTrfmClause -@init { gParent.pushMsg("transform clause", state); } -@after { gParent.popMsg(state); } - : - LPAREN selectExpressionList RPAREN - inSerde=rowFormat inRec=recordWriter - KW_USING StringLiteral - ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? - outSerde=rowFormat outRec=recordReader - -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) - ; - -hintClause -@init { gParent.pushMsg("hint clause", state); } -@after { gParent.popMsg(state); } - : - DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList) - ; - -hintList -@init { gParent.pushMsg("hint list", state); } -@after { gParent.popMsg(state); } - : - hintItem (COMMA hintItem)* -> hintItem+ - ; - -hintItem -@init { gParent.pushMsg("hint item", state); } -@after { gParent.popMsg(state); } - : - hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?) - ; - -hintName -@init { gParent.pushMsg("hint name", state); } -@after { gParent.popMsg(state); } - : - KW_MAPJOIN -> TOK_MAPJOIN - | KW_STREAMTABLE -> TOK_STREAMTABLE - ; - -hintArgs -@init { gParent.pushMsg("hint arguments", state); } -@after { gParent.popMsg(state); } - : - hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+) - ; - -hintArgName -@init { gParent.pushMsg("hint argument name", state); } -@after { gParent.popMsg(state); } - : - identifier - ; - -selectItem -@init { gParent.pushMsg("selection target", state); } -@after { gParent.popMsg(state); } - : - (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) - | - namedExpression - ; - -namedExpression -@init { gParent.pushMsg("select named expression", state); } -@after { gParent.popMsg(state); } - : - ( expression - ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? - ) -> ^(TOK_SELEXPR expression identifier*) - ; - -trfmClause -@init { gParent.pushMsg("transform clause", state); } -@after { gParent.popMsg(state); } - : - ( KW_MAP selectExpressionList - | KW_REDUCE selectExpressionList ) - inSerde=rowFormat inRec=recordWriter - KW_USING StringLiteral - ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? - outSerde=rowFormat outRec=recordReader - -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) 
- ; - -selectExpression -@init { gParent.pushMsg("select expression", state); } -@after { gParent.popMsg(state); } - : - (tableAllColumns) => tableAllColumns - | - expression - ; - -selectExpressionList -@init { gParent.pushMsg("select expression list", state); } -@after { gParent.popMsg(state); } - : - selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+) - ; - -//---------------------- Rules for windowing clauses ------------------------------- -window_clause -@init { gParent.pushMsg("window_clause", state); } -@after { gParent.popMsg(state); } -: - KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+) -; - -window_defn -@init { gParent.pushMsg("window_defn", state); } -@after { gParent.popMsg(state); } -: - Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification) -; - -window_specification -@init { gParent.pushMsg("window_specification", state); } -@after { gParent.popMsg(state); } -: - (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?) -; - -window_frame : - window_range_expression | - window_value_expression -; - -window_range_expression -@init { gParent.pushMsg("window_range_expression", state); } -@after { gParent.popMsg(state); } -: - KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) | - KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end) -; - -window_value_expression -@init { gParent.pushMsg("window_value_expression", state); } -@after { gParent.popMsg(state); } -: - KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) | - KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end) -; - -window_frame_start_boundary -@init { gParent.pushMsg("windowframestartboundary", state); } -@after { gParent.popMsg(state); } -: - KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) | - KW_CURRENT KW_ROW -> ^(KW_CURRENT) | - Number KW_PRECEDING -> ^(KW_PRECEDING Number) -; - -window_frame_boundary -@init { gParent.pushMsg("windowframeboundary", state); } -@after { gParent.popMsg(state); } -: - KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) | - KW_CURRENT KW_ROW -> ^(KW_CURRENT) | - Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number) -; - diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g deleted file mode 100644 index fd1ad59207..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ /dev/null @@ -1,491 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveLexer.g grammar. -*/ -lexer grammar SparkSqlLexer; - -@lexer::header { -package org.apache.spark.sql.catalyst.parser; - -} - -@lexer::members { - private ParserConf parserConf; - private ParseErrorReporter reporter; - - public void configure(ParserConf parserConf, ParseErrorReporter reporter) { - this.parserConf = parserConf; - this.reporter = reporter; - } - - protected boolean allowQuotedId() { - if (parserConf == null) { - return true; - } - return parserConf.supportQuotedId(); - } - - @Override - public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - if (reporter != null) { - reporter.report(this, e, tokenNames); - } - } -} - -// Keywords - -KW_TRUE : 'TRUE'; -KW_FALSE : 'FALSE'; -KW_ALL : 'ALL'; -KW_NONE: 'NONE'; -KW_AND : 'AND'; -KW_OR : 'OR'; -KW_NOT : 'NOT' | '!'; -KW_LIKE : 'LIKE'; - -KW_IF : 'IF'; -KW_EXISTS : 'EXISTS'; - -KW_ASC : 'ASC'; -KW_DESC : 'DESC'; -KW_ORDER : 'ORDER'; -KW_GROUP : 'GROUP'; -KW_BY : 'BY'; -KW_HAVING : 'HAVING'; -KW_WHERE : 'WHERE'; -KW_FROM : 'FROM'; -KW_AS : 'AS'; -KW_SELECT : 'SELECT'; -KW_DISTINCT : 'DISTINCT'; -KW_INSERT : 'INSERT'; -KW_OVERWRITE : 'OVERWRITE'; -KW_OUTER : 'OUTER'; -KW_UNIQUEJOIN : 'UNIQUEJOIN'; -KW_PRESERVE : 'PRESERVE'; -KW_JOIN : 'JOIN'; -KW_LEFT : 'LEFT'; -KW_RIGHT : 'RIGHT'; -KW_FULL : 'FULL'; -KW_ANTI : 'ANTI'; -KW_ON : 'ON'; -KW_PARTITION : 'PARTITION'; -KW_PARTITIONS : 'PARTITIONS'; -KW_TABLE: 'TABLE'; -KW_TABLES: 'TABLES'; -KW_COLUMNS: 'COLUMNS'; -KW_INDEX: 'INDEX'; -KW_INDEXES: 'INDEXES'; -KW_REBUILD: 'REBUILD'; -KW_FUNCTIONS: 'FUNCTIONS'; -KW_SHOW: 'SHOW'; -KW_MSCK: 'MSCK'; -KW_REPAIR: 'REPAIR'; -KW_DIRECTORY: 'DIRECTORY'; -KW_LOCAL: 'LOCAL'; -KW_TRANSFORM : 'TRANSFORM'; -KW_USING: 'USING'; -KW_CLUSTER: 'CLUSTER'; -KW_DISTRIBUTE: 'DISTRIBUTE'; -KW_SORT: 'SORT'; -KW_UNION: 'UNION'; -KW_EXCEPT: 'EXCEPT'; -KW_LOAD: 'LOAD'; -KW_EXPORT: 'EXPORT'; -KW_IMPORT: 'IMPORT'; -KW_REPLICATION: 'REPLICATION'; -KW_METADATA: 'METADATA'; -KW_DATA: 'DATA'; -KW_INPATH: 'INPATH'; -KW_IS: 'IS'; -KW_NULL: 'NULL'; -KW_CREATE: 'CREATE'; -KW_EXTERNAL: 'EXTERNAL'; -KW_ALTER: 'ALTER'; -KW_CHANGE: 'CHANGE'; -KW_COLUMN: 'COLUMN'; -KW_FIRST: 'FIRST'; -KW_AFTER: 'AFTER'; -KW_DESCRIBE: 'DESCRIBE'; -KW_DROP: 'DROP'; -KW_RENAME: 'RENAME'; -KW_TO: 'TO'; -KW_COMMENT: 'COMMENT'; -KW_BOOLEAN: 'BOOLEAN'; -KW_TINYINT: 'TINYINT'; -KW_SMALLINT: 'SMALLINT'; -KW_INT: 'INT'; -KW_BIGINT: 'BIGINT'; -KW_FLOAT: 'FLOAT'; -KW_DOUBLE: 'DOUBLE'; -KW_DATE: 'DATE'; -KW_DATETIME: 'DATETIME'; -KW_TIMESTAMP: 'TIMESTAMP'; -KW_INTERVAL: 'INTERVAL'; -KW_DECIMAL: 'DECIMAL'; -KW_STRING: 'STRING'; -KW_CHAR: 'CHAR'; -KW_VARCHAR: 'VARCHAR'; -KW_ARRAY: 'ARRAY'; -KW_STRUCT: 'STRUCT'; -KW_MAP: 'MAP'; -KW_UNIONTYPE: 'UNIONTYPE'; -KW_REDUCE: 'REDUCE'; -KW_PARTITIONED: 'PARTITIONED'; -KW_CLUSTERED: 'CLUSTERED'; -KW_SORTED: 'SORTED'; -KW_INTO: 'INTO'; -KW_BUCKETS: 'BUCKETS'; -KW_ROW: 'ROW'; -KW_ROWS: 'ROWS'; -KW_FORMAT: 'FORMAT'; -KW_DELIMITED: 'DELIMITED'; -KW_FIELDS: 'FIELDS'; -KW_TERMINATED: 'TERMINATED'; -KW_ESCAPED: 'ESCAPED'; -KW_COLLECTION: 'COLLECTION'; -KW_ITEMS: 'ITEMS'; -KW_KEYS: 'KEYS'; -KW_KEY_TYPE: '$KEY$'; -KW_LINES: 'LINES'; -KW_STORED: 'STORED'; -KW_FILEFORMAT: 'FILEFORMAT'; -KW_INPUTFORMAT: 'INPUTFORMAT'; -KW_OUTPUTFORMAT: 'OUTPUTFORMAT'; -KW_INPUTDRIVER: 'INPUTDRIVER'; -KW_OUTPUTDRIVER: 'OUTPUTDRIVER'; -KW_ENABLE: 'ENABLE'; -KW_DISABLE: 'DISABLE'; -KW_LOCATION: 'LOCATION'; -KW_TABLESAMPLE: 'TABLESAMPLE'; -KW_BUCKET: 'BUCKET'; -KW_OUT: 'OUT'; -KW_OF: 'OF'; -KW_PERCENT: 'PERCENT'; 
-KW_CAST: 'CAST'; -KW_ADD: 'ADD'; -KW_REPLACE: 'REPLACE'; -KW_RLIKE: 'RLIKE'; -KW_REGEXP: 'REGEXP'; -KW_TEMPORARY: 'TEMPORARY'; -KW_FUNCTION: 'FUNCTION'; -KW_MACRO: 'MACRO'; -KW_FILE: 'FILE'; -KW_JAR: 'JAR'; -KW_EXPLAIN: 'EXPLAIN'; -KW_EXTENDED: 'EXTENDED'; -KW_FORMATTED: 'FORMATTED'; -KW_PRETTY: 'PRETTY'; -KW_DEPENDENCY: 'DEPENDENCY'; -KW_LOGICAL: 'LOGICAL'; -KW_SERDE: 'SERDE'; -KW_WITH: 'WITH'; -KW_DEFERRED: 'DEFERRED'; -KW_SERDEPROPERTIES: 'SERDEPROPERTIES'; -KW_DBPROPERTIES: 'DBPROPERTIES'; -KW_LIMIT: 'LIMIT'; -KW_SET: 'SET'; -KW_UNSET: 'UNSET'; -KW_TBLPROPERTIES: 'TBLPROPERTIES'; -KW_IDXPROPERTIES: 'IDXPROPERTIES'; -KW_VALUE_TYPE: '$VALUE$'; -KW_ELEM_TYPE: '$ELEM$'; -KW_DEFINED: 'DEFINED'; -KW_CASE: 'CASE'; -KW_WHEN: 'WHEN'; -KW_THEN: 'THEN'; -KW_ELSE: 'ELSE'; -KW_END: 'END'; -KW_MAPJOIN: 'MAPJOIN'; -KW_STREAMTABLE: 'STREAMTABLE'; -KW_CLUSTERSTATUS: 'CLUSTERSTATUS'; -KW_UTC: 'UTC'; -KW_UTCTIMESTAMP: 'UTC_TMESTAMP'; -KW_LONG: 'LONG'; -KW_DELETE: 'DELETE'; -KW_PLUS: 'PLUS'; -KW_MINUS: 'MINUS'; -KW_FETCH: 'FETCH'; -KW_INTERSECT: 'INTERSECT'; -KW_VIEW: 'VIEW'; -KW_IN: 'IN'; -KW_DATABASE: 'DATABASE'; -KW_DATABASES: 'DATABASES'; -KW_MATERIALIZED: 'MATERIALIZED'; -KW_SCHEMA: 'SCHEMA'; -KW_SCHEMAS: 'SCHEMAS'; -KW_GRANT: 'GRANT'; -KW_REVOKE: 'REVOKE'; -KW_SSL: 'SSL'; -KW_UNDO: 'UNDO'; -KW_LOCK: 'LOCK'; -KW_LOCKS: 'LOCKS'; -KW_UNLOCK: 'UNLOCK'; -KW_SHARED: 'SHARED'; -KW_EXCLUSIVE: 'EXCLUSIVE'; -KW_PROCEDURE: 'PROCEDURE'; -KW_UNSIGNED: 'UNSIGNED'; -KW_WHILE: 'WHILE'; -KW_READ: 'READ'; -KW_READS: 'READS'; -KW_PURGE: 'PURGE'; -KW_RANGE: 'RANGE'; -KW_ANALYZE: 'ANALYZE'; -KW_BEFORE: 'BEFORE'; -KW_BETWEEN: 'BETWEEN'; -KW_BOTH: 'BOTH'; -KW_BINARY: 'BINARY'; -KW_CROSS: 'CROSS'; -KW_CONTINUE: 'CONTINUE'; -KW_CURSOR: 'CURSOR'; -KW_TRIGGER: 'TRIGGER'; -KW_RECORDREADER: 'RECORDREADER'; -KW_RECORDWRITER: 'RECORDWRITER'; -KW_SEMI: 'SEMI'; -KW_LATERAL: 'LATERAL'; -KW_TOUCH: 'TOUCH'; -KW_ARCHIVE: 'ARCHIVE'; -KW_UNARCHIVE: 'UNARCHIVE'; -KW_COMPUTE: 'COMPUTE'; -KW_STATISTICS: 'STATISTICS'; -KW_USE: 'USE'; -KW_OPTION: 'OPTION'; -KW_CONCATENATE: 'CONCATENATE'; -KW_SHOW_DATABASE: 'SHOW_DATABASE'; -KW_UPDATE: 'UPDATE'; -KW_RESTRICT: 'RESTRICT'; -KW_CASCADE: 'CASCADE'; -KW_SKEWED: 'SKEWED'; -KW_ROLLUP: 'ROLLUP'; -KW_CUBE: 'CUBE'; -KW_DIRECTORIES: 'DIRECTORIES'; -KW_FOR: 'FOR'; -KW_WINDOW: 'WINDOW'; -KW_UNBOUNDED: 'UNBOUNDED'; -KW_PRECEDING: 'PRECEDING'; -KW_FOLLOWING: 'FOLLOWING'; -KW_CURRENT: 'CURRENT'; -KW_CURRENT_DATE: 'CURRENT_DATE'; -KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; -KW_LESS: 'LESS'; -KW_MORE: 'MORE'; -KW_OVER: 'OVER'; -KW_GROUPING: 'GROUPING'; -KW_SETS: 'SETS'; -KW_TRUNCATE: 'TRUNCATE'; -KW_NOSCAN: 'NOSCAN'; -KW_PARTIALSCAN: 'PARTIALSCAN'; -KW_USER: 'USER'; -KW_ROLE: 'ROLE'; -KW_ROLES: 'ROLES'; -KW_INNER: 'INNER'; -KW_EXCHANGE: 'EXCHANGE'; -KW_URI: 'URI'; -KW_SERVER : 'SERVER'; -KW_ADMIN: 'ADMIN'; -KW_OWNER: 'OWNER'; -KW_PRINCIPALS: 'PRINCIPALS'; -KW_COMPACT: 'COMPACT'; -KW_COMPACTIONS: 'COMPACTIONS'; -KW_TRANSACTIONS: 'TRANSACTIONS'; -KW_REWRITE : 'REWRITE'; -KW_AUTHORIZATION: 'AUTHORIZATION'; -KW_CONF: 'CONF'; -KW_VALUES: 'VALUES'; -KW_RELOAD: 'RELOAD'; -KW_YEAR: 'YEAR'|'YEARS'; -KW_MONTH: 'MONTH'|'MONTHS'; -KW_DAY: 'DAY'|'DAYS'; -KW_HOUR: 'HOUR'|'HOURS'; -KW_MINUTE: 'MINUTE'|'MINUTES'; -KW_SECOND: 'SECOND'|'SECONDS'; -KW_START: 'START'; -KW_TRANSACTION: 'TRANSACTION'; -KW_COMMIT: 'COMMIT'; -KW_ROLLBACK: 'ROLLBACK'; -KW_WORK: 'WORK'; -KW_ONLY: 'ONLY'; -KW_WRITE: 'WRITE'; -KW_ISOLATION: 'ISOLATION'; -KW_LEVEL: 'LEVEL'; -KW_SNAPSHOT: 'SNAPSHOT'; -KW_AUTOCOMMIT: 'AUTOCOMMIT'; -KW_REFRESH: 'REFRESH'; 
-KW_OPTIONS: 'OPTIONS'; -KW_WEEK: 'WEEK'|'WEEKS'; -KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; -KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; -KW_CLEAR: 'CLEAR'; -KW_LAZY: 'LAZY'; -KW_CACHE: 'CACHE'; -KW_UNCACHE: 'UNCACHE'; -KW_DFS: 'DFS'; - -KW_NATURAL: 'NATURAL'; - -// Operators -// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. - -DOT : '.'; // generated as a part of Number rule -COLON : ':' ; -COMMA : ',' ; -SEMICOLON : ';' ; - -LPAREN : '(' ; -RPAREN : ')' ; -LSQUARE : '[' ; -RSQUARE : ']' ; -LCURLY : '{'; -RCURLY : '}'; - -EQUAL : '=' | '=='; -EQUAL_NS : '<=>'; -NOTEQUAL : '<>' | '!='; -LESSTHANOREQUALTO : '<='; -LESSTHAN : '<'; -GREATERTHANOREQUALTO : '>='; -GREATERTHAN : '>'; - -DIVIDE : '/'; -PLUS : '+'; -MINUS : '-'; -STAR : '*'; -MOD : '%'; -DIV : 'DIV'; - -AMPERSAND : '&'; -TILDE : '~'; -BITWISEOR : '|'; -BITWISEXOR : '^'; -QUESTION : '?'; -DOLLAR : '$'; - -// LITERALS -fragment -Letter - : 'a'..'z' | 'A'..'Z' - ; - -fragment -HexDigit - : 'a'..'f' | 'A'..'F' - ; - -fragment -Digit - : - '0'..'9' - ; - -fragment -Exponent - : - ('e' | 'E') ( PLUS|MINUS )? (Digit)+ - ; - -fragment -RegexComponent - : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' - | PLUS | STAR | QUESTION | MINUS | DOT - | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY - | BITWISEXOR | BITWISEOR | DOLLAR | '!' - ; - -StringLiteral - : - ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' - | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' - )+ - ; - -BigintLiteral - : - (Digit)+ 'L' - ; - -SmallintLiteral - : - (Digit)+ 'S' - ; - -TinyintLiteral - : - (Digit)+ 'Y' - ; - -DoubleLiteral - : - Number 'D' - ; - -ByteLengthLiteral - : - (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G') - ; - -Number - : - ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent? - ; - -/* -An Identifier can be: -- tableName -- columnName -- select expr alias -- lateral view aliases -- database name -- view name -- subquery alias -- function name -- ptf argument identifier -- index name -- property name for: db,tbl,partition... -- fileFormat -- role name -- privilege name -- principal name -- macro name -- hint name -- window name -*/ -Identifier - : - (Letter | Digit | '_')+ - | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; - at the API level only columns are allowed to be of this form */ - | '`' RegexComponent+ '`' - ; - -fragment -QuotedIdentifier - : - '`' ( '``' | ~('`') )* '`' { setText(getText().replaceAll("``", "`")); } - ; - -WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} - ; - -COMMENT - : '--' (~('\n'|'\r'))* - { $channel=HIDDEN; } - ; - -/* Prevent that the lexer swallows unknown characters. */ -ANY - :. - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g deleted file mode 100644 index f0c236859d..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ /dev/null @@ -1,2596 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. 
You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveParser.g grammar. -*/ -parser grammar SparkSqlParser; - -options -{ -tokenVocab=SparkSqlLexer; -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} -import SelectClauseParser, FromClauseParser, IdentifiersParser, KeywordParser, ExpressionParser; - -tokens { -TOK_INSERT; -TOK_QUERY; -TOK_SELECT; -TOK_SELECTDI; -TOK_SELEXPR; -TOK_FROM; -TOK_TAB; -TOK_PARTSPEC; -TOK_PARTVAL; -TOK_DIR; -TOK_TABREF; -TOK_SUBQUERY; -TOK_INSERT_INTO; -TOK_DESTINATION; -TOK_ALLCOLREF; -TOK_TABLE_OR_COL; -TOK_FUNCTION; -TOK_FUNCTIONDI; -TOK_FUNCTIONSTAR; -TOK_WHERE; -TOK_OP_EQ; -TOK_OP_NE; -TOK_OP_LE; -TOK_OP_LT; -TOK_OP_GE; -TOK_OP_GT; -TOK_OP_DIV; -TOK_OP_ADD; -TOK_OP_SUB; -TOK_OP_MUL; -TOK_OP_MOD; -TOK_OP_BITAND; -TOK_OP_BITNOT; -TOK_OP_BITOR; -TOK_OP_BITXOR; -TOK_OP_AND; -TOK_OP_OR; -TOK_OP_NOT; -TOK_OP_LIKE; -TOK_TRUE; -TOK_FALSE; -TOK_TRANSFORM; -TOK_SERDE; -TOK_SERDENAME; -TOK_SERDEPROPS; -TOK_EXPLIST; -TOK_ALIASLIST; -TOK_GROUPBY; -TOK_ROLLUP_GROUPBY; -TOK_CUBE_GROUPBY; -TOK_GROUPING_SETS; -TOK_GROUPING_SETS_EXPRESSION; -TOK_HAVING; -TOK_ORDERBY; -TOK_CLUSTERBY; -TOK_DISTRIBUTEBY; -TOK_SORTBY; -TOK_UNIONALL; -TOK_UNIONDISTINCT; -TOK_EXCEPT; -TOK_INTERSECT; -TOK_JOIN; -TOK_LEFTOUTERJOIN; -TOK_RIGHTOUTERJOIN; -TOK_FULLOUTERJOIN; -TOK_UNIQUEJOIN; -TOK_CROSSJOIN; -TOK_NATURALJOIN; -TOK_NATURALLEFTOUTERJOIN; -TOK_NATURALRIGHTOUTERJOIN; -TOK_NATURALFULLOUTERJOIN; -TOK_LOAD; -TOK_EXPORT; -TOK_IMPORT; -TOK_REPLICATION; -TOK_METADATA; -TOK_NULL; -TOK_ISNULL; -TOK_ISNOTNULL; -TOK_TINYINT; -TOK_SMALLINT; -TOK_INT; -TOK_BIGINT; -TOK_BOOLEAN; -TOK_FLOAT; -TOK_DOUBLE; -TOK_DATE; -TOK_DATELITERAL; -TOK_DATETIME; -TOK_TIMESTAMP; -TOK_TIMESTAMPLITERAL; -TOK_INTERVAL; -TOK_INTERVAL_YEAR_MONTH; -TOK_INTERVAL_YEAR_MONTH_LITERAL; -TOK_INTERVAL_DAY_TIME; -TOK_INTERVAL_DAY_TIME_LITERAL; -TOK_INTERVAL_YEAR_LITERAL; -TOK_INTERVAL_MONTH_LITERAL; -TOK_INTERVAL_WEEK_LITERAL; -TOK_INTERVAL_DAY_LITERAL; -TOK_INTERVAL_HOUR_LITERAL; -TOK_INTERVAL_MINUTE_LITERAL; -TOK_INTERVAL_SECOND_LITERAL; -TOK_INTERVAL_MILLISECOND_LITERAL; -TOK_INTERVAL_MICROSECOND_LITERAL; -TOK_STRING; -TOK_CHAR; -TOK_VARCHAR; -TOK_BINARY; -TOK_DECIMAL; -TOK_LIST; -TOK_STRUCT; -TOK_MAP; -TOK_UNIONTYPE; -TOK_COLTYPELIST; -TOK_CREATEDATABASE; -TOK_CREATETABLE; -TOK_CREATETABLEUSING; -TOK_TRUNCATETABLE; -TOK_CREATEINDEX; -TOK_CREATEINDEX_INDEXTBLNAME; -TOK_DEFERRED_REBUILDINDEX; -TOK_DROPINDEX; -TOK_LIKETABLE; -TOK_DESCTABLE; -TOK_DESCFUNCTION; -TOK_ALTERTABLE; -TOK_ALTERTABLE_RENAME; -TOK_ALTERTABLE_ADDCOLS; -TOK_ALTERTABLE_RENAMECOL; -TOK_ALTERTABLE_RENAMEPART; -TOK_ALTERTABLE_REPLACECOLS; -TOK_ALTERTABLE_ADDPARTS; -TOK_ALTERTABLE_DROPPARTS; -TOK_ALTERTABLE_PARTCOLTYPE; -TOK_ALTERTABLE_MERGEFILES; -TOK_ALTERTABLE_TOUCH; -TOK_ALTERTABLE_ARCHIVE; -TOK_ALTERTABLE_UNARCHIVE; -TOK_ALTERTABLE_SERDEPROPERTIES; -TOK_ALTERTABLE_SERIALIZER; -TOK_ALTERTABLE_UPDATECOLSTATS; -TOK_TABLE_PARTITION; -TOK_ALTERTABLE_FILEFORMAT; -TOK_ALTERTABLE_LOCATION; -TOK_ALTERTABLE_PROPERTIES; -TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION; -TOK_ALTERTABLE_DROPPROPERTIES; -TOK_ALTERTABLE_SKEWED; 
-TOK_ALTERTABLE_EXCHANGEPARTITION; -TOK_ALTERTABLE_SKEWED_LOCATION; -TOK_ALTERTABLE_BUCKETS; -TOK_ALTERTABLE_CLUSTER_SORT; -TOK_ALTERTABLE_COMPACT; -TOK_ALTERINDEX_REBUILD; -TOK_ALTERINDEX_PROPERTIES; -TOK_MSCK; -TOK_SHOWDATABASES; -TOK_SHOWTABLES; -TOK_SHOWCOLUMNS; -TOK_SHOWFUNCTIONS; -TOK_SHOWPARTITIONS; -TOK_SHOW_CREATEDATABASE; -TOK_SHOW_CREATETABLE; -TOK_SHOW_TABLESTATUS; -TOK_SHOW_TBLPROPERTIES; -TOK_SHOWLOCKS; -TOK_SHOWCONF; -TOK_LOCKTABLE; -TOK_UNLOCKTABLE; -TOK_LOCKDB; -TOK_UNLOCKDB; -TOK_SWITCHDATABASE; -TOK_DROPDATABASE; -TOK_DROPTABLE; -TOK_DATABASECOMMENT; -TOK_TABCOLLIST; -TOK_TABCOL; -TOK_TABLECOMMENT; -TOK_TABLEPARTCOLS; -TOK_TABLEROWFORMAT; -TOK_TABLEROWFORMATFIELD; -TOK_TABLEROWFORMATCOLLITEMS; -TOK_TABLEROWFORMATMAPKEYS; -TOK_TABLEROWFORMATLINES; -TOK_TABLEROWFORMATNULL; -TOK_TABLEFILEFORMAT; -TOK_FILEFORMAT_GENERIC; -TOK_OFFLINE; -TOK_ENABLE; -TOK_DISABLE; -TOK_READONLY; -TOK_NO_DROP; -TOK_STORAGEHANDLER; -TOK_NOT_CLUSTERED; -TOK_NOT_SORTED; -TOK_TABCOLNAME; -TOK_TABLELOCATION; -TOK_PARTITIONLOCATION; -TOK_TABLEBUCKETSAMPLE; -TOK_TABLESPLITSAMPLE; -TOK_PERCENT; -TOK_LENGTH; -TOK_ROWCOUNT; -TOK_TMP_FILE; -TOK_TABSORTCOLNAMEASC; -TOK_TABSORTCOLNAMEDESC; -TOK_STRINGLITERALSEQUENCE; -TOK_CREATEFUNCTION; -TOK_DROPFUNCTION; -TOK_RELOADFUNCTION; -TOK_CREATEMACRO; -TOK_DROPMACRO; -TOK_TEMPORARY; -TOK_CREATEVIEW; -TOK_DROPVIEW; -TOK_ALTERVIEW; -TOK_ALTERVIEW_PROPERTIES; -TOK_ALTERVIEW_DROPPROPERTIES; -TOK_ALTERVIEW_ADDPARTS; -TOK_ALTERVIEW_DROPPARTS; -TOK_ALTERVIEW_RENAME; -TOK_VIEWPARTCOLS; -TOK_EXPLAIN; -TOK_EXPLAIN_SQ_REWRITE; -TOK_TABLESERIALIZER; -TOK_TABLEPROPERTIES; -TOK_TABLEPROPLIST; -TOK_INDEXPROPERTIES; -TOK_INDEXPROPLIST; -TOK_TABTYPE; -TOK_LIMIT; -TOK_TABLEPROPERTY; -TOK_IFEXISTS; -TOK_IFNOTEXISTS; -TOK_ORREPLACE; -TOK_HINTLIST; -TOK_HINT; -TOK_MAPJOIN; -TOK_STREAMTABLE; -TOK_HINTARGLIST; -TOK_USERSCRIPTCOLNAMES; -TOK_USERSCRIPTCOLSCHEMA; -TOK_RECORDREADER; -TOK_RECORDWRITER; -TOK_LEFTSEMIJOIN; -TOK_ANTIJOIN; -TOK_LATERAL_VIEW; -TOK_LATERAL_VIEW_OUTER; -TOK_TABALIAS; -TOK_ANALYZE; -TOK_CREATEROLE; -TOK_DROPROLE; -TOK_GRANT; -TOK_REVOKE; -TOK_SHOW_GRANT; -TOK_PRIVILEGE_LIST; -TOK_PRIVILEGE; -TOK_PRINCIPAL_NAME; -TOK_USER; -TOK_GROUP; -TOK_ROLE; -TOK_RESOURCE_ALL; -TOK_GRANT_WITH_OPTION; -TOK_GRANT_WITH_ADMIN_OPTION; -TOK_ADMIN_OPTION_FOR; -TOK_GRANT_OPTION_FOR; -TOK_PRIV_ALL; -TOK_PRIV_ALTER_METADATA; -TOK_PRIV_ALTER_DATA; -TOK_PRIV_DELETE; -TOK_PRIV_DROP; -TOK_PRIV_INDEX; -TOK_PRIV_INSERT; -TOK_PRIV_LOCK; -TOK_PRIV_SELECT; -TOK_PRIV_SHOW_DATABASE; -TOK_PRIV_CREATE; -TOK_PRIV_OBJECT; -TOK_PRIV_OBJECT_COL; -TOK_GRANT_ROLE; -TOK_REVOKE_ROLE; -TOK_SHOW_ROLE_GRANT; -TOK_SHOW_ROLES; -TOK_SHOW_SET_ROLE; -TOK_SHOW_ROLE_PRINCIPALS; -TOK_SHOWINDEXES; -TOK_SHOWDBLOCKS; -TOK_INDEXCOMMENT; -TOK_DESCDATABASE; -TOK_DATABASEPROPERTIES; -TOK_DATABASELOCATION; -TOK_DBPROPLIST; -TOK_ALTERDATABASE_PROPERTIES; -TOK_ALTERDATABASE_OWNER; -TOK_TABNAME; -TOK_TABSRC; -TOK_RESTRICT; -TOK_CASCADE; -TOK_TABLESKEWED; -TOK_TABCOLVALUE; -TOK_TABCOLVALUE_PAIR; -TOK_TABCOLVALUES; -TOK_SKEWED_LOCATIONS; -TOK_SKEWED_LOCATION_LIST; -TOK_SKEWED_LOCATION_MAP; -TOK_STOREDASDIRS; -TOK_PARTITIONINGSPEC; -TOK_PTBLFUNCTION; -TOK_WINDOWDEF; -TOK_WINDOWSPEC; -TOK_WINDOWVALUES; -TOK_WINDOWRANGE; -TOK_SUBQUERY_EXPR; -TOK_SUBQUERY_OP; -TOK_SUBQUERY_OP_NOTIN; -TOK_SUBQUERY_OP_NOTEXISTS; -TOK_DB_TYPE; -TOK_TABLE_TYPE; -TOK_CTE; -TOK_ARCHIVE; -TOK_FILE; -TOK_JAR; -TOK_RESOURCE_URI; -TOK_RESOURCE_LIST; -TOK_SHOW_COMPACTIONS; -TOK_SHOW_TRANSACTIONS; -TOK_DELETE_FROM; -TOK_UPDATE_TABLE; -TOK_SET_COLUMNS_CLAUSE; 
-TOK_VALUE_ROW; -TOK_VALUES_TABLE; -TOK_VIRTUAL_TABLE; -TOK_VIRTUAL_TABREF; -TOK_ANONYMOUS; -TOK_COL_NAME; -TOK_URI_TYPE; -TOK_SERVER_TYPE; -TOK_START_TRANSACTION; -TOK_ISOLATION_LEVEL; -TOK_ISOLATION_SNAPSHOT; -TOK_TXN_ACCESS_MODE; -TOK_TXN_READ_ONLY; -TOK_TXN_READ_WRITE; -TOK_COMMIT; -TOK_ROLLBACK; -TOK_SET_AUTOCOMMIT; -TOK_REFRESHTABLE; -TOK_TABLEPROVIDER; -TOK_TABLEOPTIONS; -TOK_TABLEOPTION; -TOK_CACHETABLE; -TOK_UNCACHETABLE; -TOK_CLEARCACHE; -TOK_SETCONFIG; -TOK_DFS; -TOK_ADDFILE; -TOK_ADDJAR; -TOK_USING; -} - - -// Package headers -@header { -package org.apache.spark.sql.catalyst.parser; - -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -} - - -@members { - Stack msgs = new Stack(); - - private static HashMap xlateMap; - static { - //this is used to support auto completion in CLI - xlateMap = new HashMap(); - - // Keywords - xlateMap.put("KW_TRUE", "TRUE"); - xlateMap.put("KW_FALSE", "FALSE"); - xlateMap.put("KW_ALL", "ALL"); - xlateMap.put("KW_NONE", "NONE"); - xlateMap.put("KW_AND", "AND"); - xlateMap.put("KW_OR", "OR"); - xlateMap.put("KW_NOT", "NOT"); - xlateMap.put("KW_LIKE", "LIKE"); - - xlateMap.put("KW_ASC", "ASC"); - xlateMap.put("KW_DESC", "DESC"); - xlateMap.put("KW_ORDER", "ORDER"); - xlateMap.put("KW_BY", "BY"); - xlateMap.put("KW_GROUP", "GROUP"); - xlateMap.put("KW_WHERE", "WHERE"); - xlateMap.put("KW_FROM", "FROM"); - xlateMap.put("KW_AS", "AS"); - xlateMap.put("KW_SELECT", "SELECT"); - xlateMap.put("KW_DISTINCT", "DISTINCT"); - xlateMap.put("KW_INSERT", "INSERT"); - xlateMap.put("KW_OVERWRITE", "OVERWRITE"); - xlateMap.put("KW_OUTER", "OUTER"); - xlateMap.put("KW_JOIN", "JOIN"); - xlateMap.put("KW_LEFT", "LEFT"); - xlateMap.put("KW_RIGHT", "RIGHT"); - xlateMap.put("KW_FULL", "FULL"); - xlateMap.put("KW_ON", "ON"); - xlateMap.put("KW_PARTITION", "PARTITION"); - xlateMap.put("KW_PARTITIONS", "PARTITIONS"); - xlateMap.put("KW_TABLE", "TABLE"); - xlateMap.put("KW_TABLES", "TABLES"); - xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES"); - xlateMap.put("KW_SHOW", "SHOW"); - xlateMap.put("KW_MSCK", "MSCK"); - xlateMap.put("KW_DIRECTORY", "DIRECTORY"); - xlateMap.put("KW_LOCAL", "LOCAL"); - xlateMap.put("KW_TRANSFORM", "TRANSFORM"); - xlateMap.put("KW_USING", "USING"); - xlateMap.put("KW_CLUSTER", "CLUSTER"); - xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE"); - xlateMap.put("KW_SORT", "SORT"); - xlateMap.put("KW_UNION", "UNION"); - xlateMap.put("KW_LOAD", "LOAD"); - xlateMap.put("KW_DATA", "DATA"); - xlateMap.put("KW_INPATH", "INPATH"); - xlateMap.put("KW_IS", "IS"); - xlateMap.put("KW_NULL", "NULL"); - xlateMap.put("KW_CREATE", "CREATE"); - xlateMap.put("KW_EXTERNAL", "EXTERNAL"); - xlateMap.put("KW_ALTER", "ALTER"); - xlateMap.put("KW_DESCRIBE", "DESCRIBE"); - xlateMap.put("KW_DROP", "DROP"); - xlateMap.put("KW_RENAME", "RENAME"); - xlateMap.put("KW_TO", "TO"); - xlateMap.put("KW_COMMENT", "COMMENT"); - xlateMap.put("KW_BOOLEAN", "BOOLEAN"); - xlateMap.put("KW_TINYINT", "TINYINT"); - xlateMap.put("KW_SMALLINT", "SMALLINT"); - xlateMap.put("KW_INT", "INT"); - xlateMap.put("KW_BIGINT", "BIGINT"); - xlateMap.put("KW_FLOAT", "FLOAT"); - xlateMap.put("KW_DOUBLE", "DOUBLE"); - xlateMap.put("KW_DATE", "DATE"); - xlateMap.put("KW_DATETIME", "DATETIME"); - xlateMap.put("KW_TIMESTAMP", "TIMESTAMP"); - xlateMap.put("KW_STRING", "STRING"); - xlateMap.put("KW_BINARY", "BINARY"); - xlateMap.put("KW_ARRAY", "ARRAY"); - xlateMap.put("KW_MAP", "MAP"); - xlateMap.put("KW_REDUCE", "REDUCE"); - xlateMap.put("KW_PARTITIONED", "PARTITIONED"); - 
xlateMap.put("KW_CLUSTERED", "CLUSTERED"); - xlateMap.put("KW_SORTED", "SORTED"); - xlateMap.put("KW_INTO", "INTO"); - xlateMap.put("KW_BUCKETS", "BUCKETS"); - xlateMap.put("KW_ROW", "ROW"); - xlateMap.put("KW_FORMAT", "FORMAT"); - xlateMap.put("KW_DELIMITED", "DELIMITED"); - xlateMap.put("KW_FIELDS", "FIELDS"); - xlateMap.put("KW_TERMINATED", "TERMINATED"); - xlateMap.put("KW_COLLECTION", "COLLECTION"); - xlateMap.put("KW_ITEMS", "ITEMS"); - xlateMap.put("KW_KEYS", "KEYS"); - xlateMap.put("KW_KEY_TYPE", "\$KEY\$"); - xlateMap.put("KW_LINES", "LINES"); - xlateMap.put("KW_STORED", "STORED"); - xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE"); - xlateMap.put("KW_TEXTFILE", "TEXTFILE"); - xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT"); - xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT"); - xlateMap.put("KW_LOCATION", "LOCATION"); - xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE"); - xlateMap.put("KW_BUCKET", "BUCKET"); - xlateMap.put("KW_OUT", "OUT"); - xlateMap.put("KW_OF", "OF"); - xlateMap.put("KW_CAST", "CAST"); - xlateMap.put("KW_ADD", "ADD"); - xlateMap.put("KW_REPLACE", "REPLACE"); - xlateMap.put("KW_COLUMNS", "COLUMNS"); - xlateMap.put("KW_RLIKE", "RLIKE"); - xlateMap.put("KW_REGEXP", "REGEXP"); - xlateMap.put("KW_TEMPORARY", "TEMPORARY"); - xlateMap.put("KW_FUNCTION", "FUNCTION"); - xlateMap.put("KW_EXPLAIN", "EXPLAIN"); - xlateMap.put("KW_EXTENDED", "EXTENDED"); - xlateMap.put("KW_SERDE", "SERDE"); - xlateMap.put("KW_WITH", "WITH"); - xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES"); - xlateMap.put("KW_LIMIT", "LIMIT"); - xlateMap.put("KW_SET", "SET"); - xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES"); - xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$"); - xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$"); - xlateMap.put("KW_DEFINED", "DEFINED"); - xlateMap.put("KW_SUBQUERY", "SUBQUERY"); - xlateMap.put("KW_REWRITE", "REWRITE"); - xlateMap.put("KW_UPDATE", "UPDATE"); - xlateMap.put("KW_VALUES", "VALUES"); - xlateMap.put("KW_PURGE", "PURGE"); - xlateMap.put("KW_WEEK", "WEEK"); - xlateMap.put("KW_MILLISECOND", "MILLISECOND"); - xlateMap.put("KW_MICROSECOND", "MICROSECOND"); - xlateMap.put("KW_CLEAR", "CLEAR"); - xlateMap.put("KW_LAZY", "LAZY"); - xlateMap.put("KW_CACHE", "CACHE"); - xlateMap.put("KW_UNCACHE", "UNCACHE"); - xlateMap.put("KW_DFS", "DFS"); - - // Operators - xlateMap.put("DOT", "."); - xlateMap.put("COLON", ":"); - xlateMap.put("COMMA", ","); - xlateMap.put("SEMICOLON", ");"); - - xlateMap.put("LPAREN", "("); - xlateMap.put("RPAREN", ")"); - xlateMap.put("LSQUARE", "["); - xlateMap.put("RSQUARE", "]"); - - xlateMap.put("EQUAL", "="); - xlateMap.put("NOTEQUAL", "<>"); - xlateMap.put("EQUAL_NS", "<=>"); - xlateMap.put("LESSTHANOREQUALTO", "<="); - xlateMap.put("LESSTHAN", "<"); - xlateMap.put("GREATERTHANOREQUALTO", ">="); - xlateMap.put("GREATERTHAN", ">"); - - xlateMap.put("DIVIDE", "/"); - xlateMap.put("PLUS", "+"); - xlateMap.put("MINUS", "-"); - xlateMap.put("STAR", "*"); - xlateMap.put("MOD", "\%"); - - xlateMap.put("AMPERSAND", "&"); - xlateMap.put("TILDE", "~"); - xlateMap.put("BITWISEOR", "|"); - xlateMap.put("BITWISEXOR", "^"); - xlateMap.put("CharSetLiteral", "\\'"); - } - - public static Collection getKeywords() { - return xlateMap.values(); - } - - private static String xlate(String name) { - - String ret = xlateMap.get(name); - if (ret == null) { - ret = name; - } - - return ret; - } - - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - - @Override - public void 
displayRecognitionError(String[] tokenNames, RecognitionException e) { - if (reporter != null) { - reporter.report(this, e, tokenNames); - } - } - - @Override - public String getErrorHeader(RecognitionException e) { - String header = null; - if (e.charPositionInLine < 0 && input.LT(-1) != null) { - Token t = input.LT(-1); - header = "line " + t.getLine() + ":" + t.getCharPositionInLine(); - } else { - header = super.getErrorHeader(e); - } - - return header; - } - - @Override - public String getErrorMessage(RecognitionException e, String[] tokenNames) { - String msg = null; - - // Translate the token names to something that the user can understand - String[] xlateNames = new String[tokenNames.length]; - for (int i = 0; i < tokenNames.length; ++i) { - xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]); - } - - if (e instanceof NoViableAltException) { - @SuppressWarnings("unused") - NoViableAltException nvae = (NoViableAltException) e; - // for development, can add - // "decision=<<"+nvae.grammarDecisionDescription+">>" - // and "(decision="+nvae.decisionNumber+") and - // "state "+nvae.stateNumber - msg = "cannot recognize input near" - + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "") - + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "") - + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : ""); - } else if (e instanceof MismatchedTokenException) { - MismatchedTokenException mte = (MismatchedTokenException) e; - msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "":" near '" + input.LT(-1).getText()) + "'"; - } else if (e instanceof FailedPredicateException) { - FailedPredicateException fpe = (FailedPredicateException) e; - msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'"; - } else { - msg = super.getErrorMessage(e, xlateNames); - } - - if (msgs.size() > 0) { - msg = msg + " in " + msgs.peek(); - } - return msg; - } - - public void pushMsg(String msg, RecognizerSharedState state) { - // ANTLR generated code does not wrap the @init code wit this backtracking check, - // even if the matching @after has it. If we have parser rules with that are doing - // some lookahead with syntactic predicates this can cause the push() and pop() calls - // to become unbalanced, so make sure both push/pop check the backtracking state. - if (state.backtracking == 0) { - msgs.push(msg); - } - } - - public void popMsg(RecognizerSharedState state) { - if (state.backtracking == 0) { - Object o = msgs.pop(); - } - } - - // counter to generate unique union aliases - private int aliasCounter; - private String generateUnionAlias() { - return "u_" + (++aliasCounter); - } - private char [] excludedCharForColumnName = {'.', ':'}; - private boolean containExcludedCharForCreateTableColumnName(String input) { - if (input.length() > 0) { - if (input.charAt(0) == '`' && input.charAt(input.length() - 1) == '`') { - // When column name is backquoted, we don't care about excluded chars. 
- return false; - } - } - for(char c : excludedCharForColumnName) { - if(input.indexOf(c)>-1) { - return true; - } - } - return false; - } - private CommonTree throwSetOpException() throws RecognitionException { - throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", ""); - } - private CommonTree throwColumnNameException() throws RecognitionException { - throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); - } - - private ParserConf parserConf; - private ParseErrorReporter reporter; - - public void configure(ParserConf parserConf, ParseErrorReporter reporter) { - this.parserConf = parserConf; - this.reporter = reporter; - } - - protected boolean useSQL11ReservedKeywordsForIdentifier() { - if (parserConf == null) { - return true; - } - return !parserConf.supportSQL11ReservedKeywords(); - } -} - -@rulecatch { -catch (RecognitionException e) { - reportError(e); - throw e; -} -} - -// starting rule -statement - : explainStatement EOF - | execStatement EOF - | KW_ADD KW_JAR -> ^(TOK_ADDJAR) - | KW_ADD KW_FILE -> ^(TOK_ADDFILE) - | KW_DFS -> ^(TOK_DFS) - | (KW_SET)=> KW_SET -> ^(TOK_SETCONFIG) - ; - -// Rule for expression parsing -singleNamedExpression - : - namedExpression EOF - ; - -// Rule for table name parsing -singleTableName - : - tableName EOF - ; - -explainStatement -@init { pushMsg("explain statement", state); } -@after { popMsg(state); } - : KW_EXPLAIN ( - explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*) - | - KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression)) - ; - -explainOption -@init { msgs.push("explain option"); } -@after { msgs.pop(); } - : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION - ; - -execStatement -@init { pushMsg("statement", state); } -@after { popMsg(state); } - : queryStatementExpression[true] - | loadStatement - | exportStatement - | importStatement - | ddlStatement - | deleteStatement - | updateStatement - | sqlTransactionStatement - | cacheStatement - ; - -loadStatement -@init { pushMsg("load statement", state); } -@after { popMsg(state); } - : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition) - -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?) - ; - -replicationClause -@init { pushMsg("replication clause", state); } -@after { popMsg(state); } - : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN - -> ^(TOK_REPLICATION $replId $isMetadataOnly?) - ; - -exportStatement -@init { pushMsg("export statement", state); } -@after { popMsg(state); } - : KW_EXPORT - KW_TABLE (tab=tableOrPartition) - KW_TO (path=StringLiteral) - replicationClause? - -> ^(TOK_EXPORT $tab $path replicationClause?) - ; - -importStatement -@init { pushMsg("import statement", state); } -@after { popMsg(state); } - : KW_IMPORT - ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))? - KW_FROM (path=StringLiteral) - tableLocation? - -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?) 
- ; - -ddlStatement -@init { pushMsg("ddl statement", state); } -@after { popMsg(state); } - : createDatabaseStatement - | switchDatabaseStatement - | dropDatabaseStatement - | createTableStatement - | dropTableStatement - | truncateTableStatement - | alterStatement - | descStatement - | refreshStatement - | showStatement - | metastoreCheck - | createViewStatement - | dropViewStatement - | createFunctionStatement - | createMacroStatement - | createIndexStatement - | dropIndexStatement - | dropFunctionStatement - | reloadFunctionStatement - | dropMacroStatement - | analyzeStatement - | lockStatement - | unlockStatement - | lockDatabase - | unlockDatabase - | createRoleStatement - | dropRoleStatement - | (grantPrivileges) => grantPrivileges - | (revokePrivileges) => revokePrivileges - | showGrants - | showRoleGrants - | showRolePrincipals - | showRoles - | grantRole - | revokeRole - | setRole - | showCurrentRole - ; - -ifExists -@init { pushMsg("if exists clause", state); } -@after { popMsg(state); } - : KW_IF KW_EXISTS - -> ^(TOK_IFEXISTS) - ; - -restrictOrCascade -@init { pushMsg("restrict or cascade clause", state); } -@after { popMsg(state); } - : KW_RESTRICT - -> ^(TOK_RESTRICT) - | KW_CASCADE - -> ^(TOK_CASCADE) - ; - -ifNotExists -@init { pushMsg("if not exists clause", state); } -@after { popMsg(state); } - : KW_IF KW_NOT KW_EXISTS - -> ^(TOK_IFNOTEXISTS) - ; - -storedAsDirs -@init { pushMsg("stored as directories", state); } -@after { popMsg(state); } - : KW_STORED KW_AS KW_DIRECTORIES - -> ^(TOK_STOREDASDIRS) - ; - -orReplace -@init { pushMsg("or replace clause", state); } -@after { popMsg(state); } - : KW_OR KW_REPLACE - -> ^(TOK_ORREPLACE) - ; - -createDatabaseStatement -@init { pushMsg("create database statement", state); } -@after { popMsg(state); } - : KW_CREATE (KW_DATABASE|KW_SCHEMA) - ifNotExists? - name=identifier - databaseComment? - dbLocation? - (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)? - -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?) - ; - -dbLocation -@init { pushMsg("database location specification", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn) - ; - -dbProperties -@init { pushMsg("dbproperties", state); } -@after { popMsg(state); } - : - LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList) - ; - -dbPropertiesList -@init { pushMsg("database properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+) - ; - - -switchDatabaseStatement -@init { pushMsg("switch database statement", state); } -@after { popMsg(state); } - : KW_USE identifier - -> ^(TOK_SWITCHDATABASE identifier) - ; - -dropDatabaseStatement -@init { pushMsg("drop database statement", state); } -@after { popMsg(state); } - : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade? - -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?) - ; - -databaseComment -@init { pushMsg("database's comment", state); } -@after { popMsg(state); } - : KW_COMMENT comment=StringLiteral - -> ^(TOK_DATABASECOMMENT $comment) - ; - -createTableStatement -@init { pushMsg("create table statement", state); } -@after { popMsg(state); } - : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName - ( - like=KW_LIKE likeName=tableName - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? 
- ^(TOK_LIKETABLE $likeName?) - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - ) - | - (tableProvider) => tableProvider - tableOpts? - (KW_AS selectStatementWithCTE)? - -> ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? - tableProvider - tableOpts? - selectStatementWithCTE? - ) - | (LPAREN columnNameTypeList RPAREN)? - (p=tableProvider?) - tableOpts? - tableComment? - tablePartition? - tableBuckets? - tableSkewed? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - (KW_AS selectStatementWithCTE)? - -> {p != null}? - ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? - columnNameTypeList? - $p - tableOpts? - selectStatementWithCTE? - ) - -> - ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? - ^(TOK_LIKETABLE $likeName?) - columnNameTypeList? - tableComment? - tablePartition? - tableBuckets? - tableSkewed? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - selectStatementWithCTE? - ) - ) - ; - -truncateTableStatement -@init { pushMsg("truncate table statement", state); } -@after { popMsg(state); } - : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?); - -createIndexStatement -@init { pushMsg("create index statement", state);} -@after {popMsg(state);} - : KW_CREATE KW_INDEX indexName=identifier - KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN - KW_AS typeName=StringLiteral - autoRebuild? - indexPropertiesPrefixed? - indexTblName? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - indexComment? - ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols - autoRebuild? - indexPropertiesPrefixed? - indexTblName? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - indexComment?) - ; - -indexComment -@init { pushMsg("comment on an index", state);} -@after {popMsg(state);} - : - KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment) - ; - -autoRebuild -@init { pushMsg("auto rebuild index", state);} -@after {popMsg(state);} - : KW_WITH KW_DEFERRED KW_REBUILD - ->^(TOK_DEFERRED_REBUILDINDEX) - ; - -indexTblName -@init { pushMsg("index table name", state);} -@after {popMsg(state);} - : KW_IN KW_TABLE indexTbl=tableName - ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl) - ; - -indexPropertiesPrefixed -@init { pushMsg("table properties with prefix", state); } -@after { popMsg(state); } - : - KW_IDXPROPERTIES! indexProperties - ; - -indexProperties -@init { pushMsg("index properties", state); } -@after { popMsg(state); } - : - LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList) - ; - -indexPropertiesList -@init { pushMsg("index properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+) - ; - -dropIndexStatement -@init { pushMsg("drop index statement", state);} -@after {popMsg(state);} - : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName - ->^(TOK_DROPINDEX $indexName $tab ifExists?) - ; - -dropTableStatement -@init { pushMsg("drop statement", state); } -@after { popMsg(state); } - : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause? - -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?) 
- ; - -alterStatement -@init { pushMsg("alter statement", state); } -@after { popMsg(state); } - : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix) - | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix) - | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix - | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix - ; - -alterTableStatementSuffix -@init { pushMsg("alter table statement", state); } -@after { popMsg(state); } - : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true] - | alterStatementSuffixDropPartitions[true] - | alterStatementSuffixAddPartitions[true] - | alterStatementSuffixTouch - | alterStatementSuffixArchive - | alterStatementSuffixUnArchive - | alterStatementSuffixProperties - | alterStatementSuffixSkewedby - | alterStatementSuffixExchangePartition - | alterStatementPartitionKeyType - | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec? - ; - -alterTblPartitionStatementSuffix -@init {pushMsg("alter table partition statement suffix", state);} -@after {popMsg(state);} - : alterStatementSuffixFileFormat - | alterStatementSuffixLocation - | alterStatementSuffixMergeFiles - | alterStatementSuffixSerdeProperties - | alterStatementSuffixRenamePart - | alterStatementSuffixBucketNum - | alterTblPartitionStatementSuffixSkewedLocation - | alterStatementSuffixClusterbySortby - | alterStatementSuffixCompact - | alterStatementSuffixUpdateStatsCol - | alterStatementSuffixRenameCol - | alterStatementSuffixAddCol - ; - -alterStatementPartitionKeyType -@init {msgs.push("alter partition key type"); } -@after {msgs.pop();} - : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN - -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType) - ; - -alterViewStatementSuffix -@init { pushMsg("alter view statement", state); } -@after { popMsg(state); } - : alterViewSuffixProperties - | alterStatementSuffixRename[false] - | alterStatementSuffixAddPartitions[false] - | alterStatementSuffixDropPartitions[false] - | selectStatementWithCTE - ; - -alterIndexStatementSuffix -@init { pushMsg("alter index statement", state); } -@after { popMsg(state); } - : indexName=identifier KW_ON tableName partitionSpec? - ( - KW_REBUILD - ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?) - | - KW_SET KW_IDXPROPERTIES - indexProperties - ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties) - ) - ; - -alterDatabaseStatementSuffix -@init { pushMsg("alter database statement", state); } -@after { popMsg(state); } - : alterDatabaseSuffixProperties - | alterDatabaseSuffixSetOwner - ; - -alterDatabaseSuffixProperties -@init { pushMsg("alter database properties statement", state); } -@after { popMsg(state); } - : name=identifier KW_SET KW_DBPROPERTIES dbProperties - -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties) - ; - -alterDatabaseSuffixSetOwner -@init { pushMsg("alter database set owner", state); } -@after { popMsg(state); } - : dbName=identifier KW_SET KW_OWNER principalName - -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName) - ; - -alterStatementSuffixRename[boolean table] -@init { pushMsg("rename statement", state); } -@after { popMsg(state); } - : KW_RENAME KW_TO tableName - -> { table }? 
^(TOK_ALTERTABLE_RENAME tableName) - -> ^(TOK_ALTERVIEW_RENAME tableName) - ; - -alterStatementSuffixAddCol -@init { pushMsg("add column statement", state); } -@after { popMsg(state); } - : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade? - -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?) - -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?) - ; - -alterStatementSuffixRenameCol -@init { pushMsg("rename column name", state); } -@after { popMsg(state); } - : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade? - ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?) - ; - -alterStatementSuffixUpdateStatsCol -@init { pushMsg("update column statistics", state); } -@after { popMsg(state); } - : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? - ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) - ; - -alterStatementChangeColPosition - : first=KW_FIRST|KW_AFTER afterCol=identifier - ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION ) - -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol) - ; - -alterStatementSuffixAddPartitions[boolean table] -@init { pushMsg("add partition statement", state); } -@after { popMsg(state); } - : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+ - -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) - -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) - ; - -alterStatementSuffixAddPartitionsElement - : partitionSpec partitionLocation? - ; - -alterStatementSuffixTouch -@init { pushMsg("touch statement", state); } -@after { popMsg(state); } - : KW_TOUCH (partitionSpec)* - -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*) - ; - -alterStatementSuffixArchive -@init { pushMsg("archive statement", state); } -@after { popMsg(state); } - : KW_ARCHIVE (partitionSpec)* - -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*) - ; - -alterStatementSuffixUnArchive -@init { pushMsg("unarchive statement", state); } -@after { popMsg(state); } - : KW_UNARCHIVE (partitionSpec)* - -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*) - ; - -partitionLocation -@init { pushMsg("partition location", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn) - ; - -alterStatementSuffixDropPartitions[boolean table] -@init { pushMsg("drop partition statement", state); } -@after { popMsg(state); } - : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause? - -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?) - -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?) - ; - -alterStatementSuffixProperties -@init { pushMsg("alter properties statement", state); } -@after { popMsg(state); } - : KW_SET KW_TBLPROPERTIES tableProperties - -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties) - | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties - -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?) 
- ; - -alterViewSuffixProperties -@init { pushMsg("alter view properties statement", state); } -@after { popMsg(state); } - : KW_SET KW_TBLPROPERTIES tableProperties - -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties) - | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties - -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?) - ; - -alterStatementSuffixSerdeProperties -@init { pushMsg("alter serdes statement", state); } -@after { popMsg(state); } - : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)? - -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?) - | KW_SET KW_SERDEPROPERTIES tableProperties - -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties) - ; - -tablePartitionPrefix -@init {pushMsg("table partition prefix", state);} -@after {popMsg(state);} - : tableName partitionSpec? - ->^(TOK_TABLE_PARTITION tableName partitionSpec?) - ; - -alterStatementSuffixFileFormat -@init {pushMsg("alter fileformat statement", state); } -@after {popMsg(state);} - : KW_SET KW_FILEFORMAT fileFormat - -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat) - ; - -alterStatementSuffixClusterbySortby -@init {pushMsg("alter partition cluster by sort by statement", state);} -@after {popMsg(state);} - : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED) - | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED) - | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets) - ; - -alterTblPartitionStatementSuffixSkewedLocation -@init {pushMsg("alter partition skewed location", state);} -@after {popMsg(state);} - : KW_SET KW_SKEWED KW_LOCATION skewedLocations - -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) - ; - -skewedLocations -@init { pushMsg("skewed locations", state); } -@after { popMsg(state); } - : - LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList) - ; - -skewedLocationsList -@init { pushMsg("skewed locations list", state); } -@after { popMsg(state); } - : - skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+) - ; - -skewedLocationMap -@init { pushMsg("specifying skewed location map", state); } -@after { popMsg(state); } - : - key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value) - ; - -alterStatementSuffixLocation -@init {pushMsg("alter location", state);} -@after {popMsg(state);} - : KW_SET KW_LOCATION newLoc=StringLiteral - -> ^(TOK_ALTERTABLE_LOCATION $newLoc) - ; - - -alterStatementSuffixSkewedby -@init {pushMsg("alter skewed by statement", state);} -@after{popMsg(state);} - : tableSkewed - ->^(TOK_ALTERTABLE_SKEWED tableSkewed) - | - KW_NOT KW_SKEWED - ->^(TOK_ALTERTABLE_SKEWED) - | - KW_NOT storedAsDirs - ->^(TOK_ALTERTABLE_SKEWED storedAsDirs) - ; - -alterStatementSuffixExchangePartition -@init {pushMsg("alter exchange partition", state);} -@after{popMsg(state);} - : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName - -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename) - ; - -alterStatementSuffixRenamePart -@init { pushMsg("alter table rename partition statement", state); } -@after { popMsg(state); } - : KW_RENAME KW_TO partitionSpec - ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec) - ; - -alterStatementSuffixStatsPart -@init { pushMsg("alter table stats partition statement", state); } -@after { popMsg(state); } - : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? 
- ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) - ; - -alterStatementSuffixMergeFiles -@init { pushMsg("", state); } -@after { popMsg(state); } - : KW_CONCATENATE - -> ^(TOK_ALTERTABLE_MERGEFILES) - ; - -alterStatementSuffixBucketNum -@init { pushMsg("", state); } -@after { popMsg(state); } - : KW_INTO num=Number KW_BUCKETS - -> ^(TOK_ALTERTABLE_BUCKETS $num) - ; - -alterStatementSuffixCompact -@init { msgs.push("compaction request"); } -@after { msgs.pop(); } - : KW_COMPACT compactType=StringLiteral - -> ^(TOK_ALTERTABLE_COMPACT $compactType) - ; - - -fileFormat -@init { pushMsg("file format specification", state); } -@after { popMsg(state); } - : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? - -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? $outDriver?) - | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) - ; - -tabTypeExpr -@init { pushMsg("specifying table types", state); } -@after { popMsg(state); } - : identifier (DOT^ identifier)? - (identifier (DOT^ - ( - (KW_ELEM_TYPE) => KW_ELEM_TYPE - | - (KW_KEY_TYPE) => KW_KEY_TYPE - | - (KW_VALUE_TYPE) => KW_VALUE_TYPE - | identifier - ))* - )? - ; - -partTypeExpr -@init { pushMsg("specifying table partitions", state); } -@after { popMsg(state); } - : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?) - ; - -tabPartColTypeExpr -@init { pushMsg("specifying table partitions columnName", state); } -@after { popMsg(state); } - : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) - ; - -refreshStatement -@init { pushMsg("refresh statement", state); } -@after { popMsg(state); } - : - KW_REFRESH KW_TABLE tableName -> ^(TOK_REFRESHTABLE tableName) - ; - -descStatement -@init { pushMsg("describe statement", state); } -@after { popMsg(state); } - : - (KW_DESCRIBE|KW_DESC) - ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?) - | - (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?) - | - (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions) - | - parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype) - ) - ; - -analyzeStatement -@init { pushMsg("analyze statement", state); } -@after { popMsg(state); } - : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) - | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? - -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) - ; - -showStatement -@init { pushMsg("show statement", state); } -@after { popMsg(state); } - : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) - | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES ^(TOK_FROM $db_name)? showStmtIdentifier?) - | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? - -> ^(TOK_SHOWCOLUMNS tableName $db_name?) - | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) 
- | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) - | KW_SHOW KW_CREATE ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) - | - KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName) - ) - | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? - -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) - | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) - | KW_SHOW KW_LOCKS - ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) - | - (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?) - ) - | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)? - -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?) - | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS) - | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS) - | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral) - ; - -lockStatement -@init { pushMsg("lock statement", state); } -@after { popMsg(state); } - : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?) - ; - -lockDatabase -@init { pushMsg("lock database statement", state); } -@after { popMsg(state); } - : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode) - ; - -lockMode -@init { pushMsg("lock mode", state); } -@after { popMsg(state); } - : KW_SHARED | KW_EXCLUSIVE - ; - -unlockStatement -@init { pushMsg("unlock statement", state); } -@after { popMsg(state); } - : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?) - ; - -unlockDatabase -@init { pushMsg("unlock database statement", state); } -@after { popMsg(state); } - : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName) - ; - -createRoleStatement -@init { pushMsg("create role", state); } -@after { popMsg(state); } - : KW_CREATE KW_ROLE roleName=identifier - -> ^(TOK_CREATEROLE $roleName) - ; - -dropRoleStatement -@init {pushMsg("drop role", state);} -@after {popMsg(state);} - : KW_DROP KW_ROLE roleName=identifier - -> ^(TOK_DROPROLE $roleName) - ; - -grantPrivileges -@init {pushMsg("grant privileges", state);} -@after {popMsg(state);} - : KW_GRANT privList=privilegeList - privilegeObject? - KW_TO principalSpecification - withGrantOption? - -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?) - ; - -revokePrivileges -@init {pushMsg("revoke privileges", state);} -@afer {popMsg(state);} - : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification - -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?) - ; - -grantRole -@init {pushMsg("grant role", state);} -@after {popMsg(state);} - : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption? - -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+) - ; - -revokeRole -@init {pushMsg("revoke role", state);} -@after {popMsg(state);} - : KW_REVOKE adminOptionFor? KW_ROLE? 
identifier (COMMA identifier)* KW_FROM principalSpecification - -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+) - ; - -showRoleGrants -@init {pushMsg("show role grants", state);} -@after {popMsg(state);} - : KW_SHOW KW_ROLE KW_GRANT principalName - -> ^(TOK_SHOW_ROLE_GRANT principalName) - ; - - -showRoles -@init {pushMsg("show roles", state);} -@after {popMsg(state);} - : KW_SHOW KW_ROLES - -> ^(TOK_SHOW_ROLES) - ; - -showCurrentRole -@init {pushMsg("show current role", state);} -@after {popMsg(state);} - : KW_SHOW KW_CURRENT KW_ROLES - -> ^(TOK_SHOW_SET_ROLE) - ; - -setRole -@init {pushMsg("set role", state);} -@after {popMsg(state);} - : KW_SET KW_ROLE - ( - (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) - | - (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text]) - | - identifier -> ^(TOK_SHOW_SET_ROLE identifier) - ) - ; - -showGrants -@init {pushMsg("show grants", state);} -@after {popMsg(state);} - : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)? - -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?) - ; - -showRolePrincipals -@init {pushMsg("show role principals", state);} -@after {popMsg(state);} - : KW_SHOW KW_PRINCIPALS roleName=identifier - -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName) - ; - - -privilegeIncludeColObject -@init {pushMsg("privilege object including columns", state);} -@after {popMsg(state);} - : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL) - | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols) - ; - -privilegeObject -@init {pushMsg("privilege object", state);} -@after {popMsg(state);} - : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject) - ; - -// database or table type. Type is optional, default type is table -privObject - : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) - | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?) - | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) - | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) - ; - -privObjectCols - : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) - | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?) - | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) - | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) - ; - -privilegeList -@init {pushMsg("grant privilege list", state);} -@after {popMsg(state);} - : privlegeDef (COMMA privlegeDef)* - -> ^(TOK_PRIVILEGE_LIST privlegeDef+) - ; - -privlegeDef -@init {pushMsg("grant privilege", state);} -@after {popMsg(state);} - : privilegeType (LPAREN cols=columnNameList RPAREN)? - -> ^(TOK_PRIVILEGE privilegeType $cols?) 
- ; - -privilegeType -@init {pushMsg("privilege type", state);} -@after {popMsg(state);} - : KW_ALL -> ^(TOK_PRIV_ALL) - | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA) - | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA) - | KW_CREATE -> ^(TOK_PRIV_CREATE) - | KW_DROP -> ^(TOK_PRIV_DROP) - | KW_INDEX -> ^(TOK_PRIV_INDEX) - | KW_LOCK -> ^(TOK_PRIV_LOCK) - | KW_SELECT -> ^(TOK_PRIV_SELECT) - | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE) - | KW_INSERT -> ^(TOK_PRIV_INSERT) - | KW_DELETE -> ^(TOK_PRIV_DELETE) - ; - -principalSpecification -@init { pushMsg("user/group/role name list", state); } -@after { popMsg(state); } - : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+) - ; - -principalName -@init {pushMsg("user|group|role name", state);} -@after {popMsg(state);} - : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier) - | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier) - | KW_ROLE identifier -> ^(TOK_ROLE identifier) - ; - -withGrantOption -@init {pushMsg("with grant option", state);} -@after {popMsg(state);} - : KW_WITH KW_GRANT KW_OPTION - -> ^(TOK_GRANT_WITH_OPTION) - ; - -grantOptionFor -@init {pushMsg("grant option for", state);} -@after {popMsg(state);} - : KW_GRANT KW_OPTION KW_FOR - -> ^(TOK_GRANT_OPTION_FOR) -; - -adminOptionFor -@init {pushMsg("admin option for", state);} -@after {popMsg(state);} - : KW_ADMIN KW_OPTION KW_FOR - -> ^(TOK_ADMIN_OPTION_FOR) -; - -withAdminOption -@init {pushMsg("with admin option", state);} -@after {popMsg(state);} - : KW_WITH KW_ADMIN KW_OPTION - -> ^(TOK_GRANT_WITH_ADMIN_OPTION) - ; - -metastoreCheck -@init { pushMsg("metastore check statement", state); } -@after { popMsg(state); } - : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)? - -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?) - ; - -resourceList -@init { pushMsg("resource list", state); } -@after { popMsg(state); } - : - resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+) - ; - -resource -@init { pushMsg("resource", state); } -@after { popMsg(state); } - : - resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath) - ; - -resourceType -@init { pushMsg("resource type", state); } -@after { popMsg(state); } - : - KW_JAR -> ^(TOK_JAR) - | - KW_FILE -> ^(TOK_FILE) - | - KW_ARCHIVE -> ^(TOK_ARCHIVE) - ; - -createFunctionStatement -@init { pushMsg("create function statement", state); } -@after { popMsg(state); } - : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral - (KW_USING rList=resourceList)? - -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY) - -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?) - ; - -dropFunctionStatement -@init { pushMsg("drop function statement", state); } -@after { popMsg(state); } - : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier - -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY) - -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?) - ; - -reloadFunctionStatement -@init { pushMsg("reload function statement", state); } -@after { popMsg(state); } - : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION); - -createMacroStatement -@init { pushMsg("create macro statement", state); } -@after { popMsg(state); } - : KW_CREATE KW_TEMPORARY KW_MACRO Identifier - LPAREN columnNameTypeList? RPAREN expression - -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? 
expression) - ; - -dropMacroStatement -@init { pushMsg("drop macro statement", state); } -@after { popMsg(state); } - : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier - -> ^(TOK_DROPMACRO Identifier ifExists?) - ; - -createViewStatement -@init { - pushMsg("create view statement", state); -} -@after { popMsg(state); } - : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName - (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition? - tablePropertiesPrefixed? - KW_AS - selectStatementWithCTE - -> ^(TOK_CREATEVIEW $name orReplace? - ifNotExists? - columnNameCommentList? - tableComment? - viewPartition? - tablePropertiesPrefixed? - selectStatementWithCTE - ) - ; - -viewPartition -@init { pushMsg("view partition specification", state); } -@after { popMsg(state); } - : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN - -> ^(TOK_VIEWPARTCOLS columnNameList) - ; - -dropViewStatement -@init { pushMsg("drop view statement", state); } -@after { popMsg(state); } - : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?) - ; - -showFunctionIdentifier -@init { pushMsg("identifier for show function statement", state); } -@after { popMsg(state); } - : functionIdentifier - | StringLiteral - ; - -showStmtIdentifier -@init { pushMsg("identifier for show statement", state); } -@after { popMsg(state); } - : identifier - | StringLiteral - ; - -tableProvider -@init { pushMsg("table's provider", state); } -@after { popMsg(state); } - : - KW_USING Identifier (DOT Identifier)* - -> ^(TOK_TABLEPROVIDER Identifier+) - ; - -optionKeyValue -@init { pushMsg("table's option specification", state); } -@after { popMsg(state); } - : - (looseIdentifier (DOT looseIdentifier)*) StringLiteral - -> ^(TOK_TABLEOPTION looseIdentifier+ StringLiteral) - ; - -tableOpts -@init { pushMsg("table's options", state); } -@after { popMsg(state); } - : - KW_OPTIONS LPAREN optionKeyValue (COMMA optionKeyValue)* RPAREN - -> ^(TOK_TABLEOPTIONS optionKeyValue+) - ; - -tableComment -@init { pushMsg("table's comment", state); } -@after { popMsg(state); } - : - KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment) - ; - -tablePartition -@init { pushMsg("table partition specification", state); } -@after { popMsg(state); } - : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN - -> ^(TOK_TABLEPARTCOLS columnNameTypeList) - ; - -tableBuckets -@init { pushMsg("table buckets specification", state); } -@after { popMsg(state); } - : - KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS - -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num) - ; - -tableSkewed -@init { pushMsg("table skewed specification", state); } -@after { popMsg(state); } - : - KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)? - -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?) 
- ; - -rowFormat -@init { pushMsg("serde specification", state); } -@after { popMsg(state); } - : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde) - | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited) - | -> ^(TOK_SERDE) - ; - -recordReader -@init { pushMsg("record reader specification", state); } -@after { popMsg(state); } - : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral) - | -> ^(TOK_RECORDREADER) - ; - -recordWriter -@init { pushMsg("record writer specification", state); } -@after { popMsg(state); } - : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral) - | -> ^(TOK_RECORDWRITER) - ; - -rowFormatSerde -@init { pushMsg("serde format specification", state); } -@after { popMsg(state); } - : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? - -> ^(TOK_SERDENAME $name $serdeprops?) - ; - -rowFormatDelimited -@init { pushMsg("serde properties specification", state); } -@after { popMsg(state); } - : - KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat? - -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?) - ; - -tableRowFormat -@init { pushMsg("table row format specification", state); } -@after { popMsg(state); } - : - rowFormatDelimited - -> ^(TOK_TABLEROWFORMAT rowFormatDelimited) - | rowFormatSerde - -> ^(TOK_TABLESERIALIZER rowFormatSerde) - ; - -tablePropertiesPrefixed -@init { pushMsg("table properties with prefix", state); } -@after { popMsg(state); } - : - KW_TBLPROPERTIES! tableProperties - ; - -tableProperties -@init { pushMsg("table properties", state); } -@after { popMsg(state); } - : - LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList) - ; - -tablePropertiesList -@init { pushMsg("table properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+) - | - keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+) - ; - -keyValueProperty -@init { pushMsg("specifying key/value property", state); } -@after { popMsg(state); } - : - key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value) - ; - -keyProperty -@init { pushMsg("specifying key property", state); } -@after { popMsg(state); } - : - key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL) - ; - -tableRowFormatFieldIdentifier -@init { pushMsg("table row format's field separator", state); } -@after { popMsg(state); } - : - KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)? - -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?) 
- ; - -tableRowFormatCollItemsIdentifier -@init { pushMsg("table row format's column separator", state); } -@after { popMsg(state); } - : - KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt) - ; - -tableRowFormatMapKeysIdentifier -@init { pushMsg("table row format's map key separator", state); } -@after { popMsg(state); } - : - KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt) - ; - -tableRowFormatLinesIdentifier -@init { pushMsg("table row format's line separator", state); } -@after { popMsg(state); } - : - KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATLINES $linesIdnt) - ; - -tableRowNullFormat -@init { pushMsg("table row format's null specifier", state); } -@after { popMsg(state); } - : - KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATNULL $nullIdnt) - ; -tableFileFormat -@init { pushMsg("table file format specification", state); } -@after { popMsg(state); } - : - (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? - -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?) - | KW_STORED KW_BY storageHandler=StringLiteral - (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? - -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?) - | KW_STORED KW_AS genericSpec=identifier - -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) - ; - -tableLocation -@init { pushMsg("table location specification", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn) - ; - -columnNameTypeList -@init { pushMsg("column name type list", state); } -@after { popMsg(state); } - : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+) - ; - -columnNameColonTypeList -@init { pushMsg("column name type list", state); } -@after { popMsg(state); } - : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+) - ; - -columnNameList -@init { pushMsg("column name list", state); } -@after { popMsg(state); } - : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+) - ; - -columnName -@init { pushMsg("column name", state); } -@after { popMsg(state); } - : - identifier - ; - -extColumnName -@init { pushMsg("column name for complex types", state); } -@after { popMsg(state); } - : - identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))* - ; - -columnNameOrderList -@init { pushMsg("column name order list", state); } -@after { popMsg(state); } - : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+) - ; - -skewedValueElement -@init { pushMsg("skewed value element", state); } -@after { popMsg(state); } - : - skewedColumnValues - | skewedColumnValuePairList - ; - -skewedColumnValuePairList -@init { pushMsg("column value pair list", state); } -@after { popMsg(state); } - : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+) - ; - -skewedColumnValuePair -@init { pushMsg("column value pair", state); } -@after { popMsg(state); } - : - LPAREN colValues=skewedColumnValues RPAREN - -> ^(TOK_TABCOLVALUES $colValues) - ; - -skewedColumnValues -@init { pushMsg("column values", state); } -@after { popMsg(state); } - : skewedColumnValue (COMMA 
skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+) - ; - -skewedColumnValue -@init { pushMsg("column value", state); } -@after { popMsg(state); } - : - constant - ; - -skewedValueLocationElement -@init { pushMsg("skewed value location element", state); } -@after { popMsg(state); } - : - skewedColumnValue - | skewedColumnValuePair - ; - -columnNameOrder -@init { pushMsg("column name order", state); } -@after { popMsg(state); } - : identifier (asc=KW_ASC | desc=KW_DESC)? - -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier) - -> ^(TOK_TABSORTCOLNAMEDESC identifier) - ; - -columnNameCommentList -@init { pushMsg("column name comment list", state); } -@after { popMsg(state); } - : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+) - ; - -columnNameComment -@init { pushMsg("column name comment", state); } -@after { popMsg(state); } - : colName=identifier (KW_COMMENT comment=StringLiteral)? - -> ^(TOK_TABCOL $colName TOK_NULL $comment?) - ; - -columnRefOrder -@init { pushMsg("column order", state); } -@after { popMsg(state); } - : expression (asc=KW_ASC | desc=KW_DESC)? - -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression) - -> ^(TOK_TABSORTCOLNAMEDESC expression) - ; - -columnNameType -@init { pushMsg("column specification", state); } -@after { popMsg(state); } - : colName=identifier colType (KW_COMMENT comment=StringLiteral)? - -> {containExcludedCharForCreateTableColumnName($colName.text)}? {throwColumnNameException()} - -> {$comment == null}? ^(TOK_TABCOL $colName colType) - -> ^(TOK_TABCOL $colName colType $comment) - ; - -columnNameColonType -@init { pushMsg("column specification", state); } -@after { popMsg(state); } - : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)? - -> {$comment == null}? ^(TOK_TABCOL $colName colType) - -> ^(TOK_TABCOL $colName colType $comment) - ; - -colType -@init { pushMsg("column type", state); } -@after { popMsg(state); } - : type - ; - -colTypeList -@init { pushMsg("column type list", state); } -@after { popMsg(state); } - : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+) - ; - -type - : primitiveType - | listType - | structType - | mapType - | unionType; - -primitiveType -@init { pushMsg("primitive type specification", state); } -@after { popMsg(state); } - : KW_TINYINT -> TOK_TINYINT - | KW_SMALLINT -> TOK_SMALLINT - | KW_INT -> TOK_INT - | KW_BIGINT -> TOK_BIGINT - | KW_LONG -> TOK_BIGINT - | KW_BOOLEAN -> TOK_BOOLEAN - | KW_FLOAT -> TOK_FLOAT - | KW_DOUBLE -> TOK_DOUBLE - | KW_DATE -> TOK_DATE - | KW_DATETIME -> TOK_DATETIME - | KW_TIMESTAMP -> TOK_TIMESTAMP - // Uncomment to allow intervals as table column types - //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH - //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME - | KW_STRING -> TOK_STRING - | KW_BINARY -> TOK_BINARY - | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?) 
- | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length) - | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length) - ; - -listType -@init { pushMsg("list type", state); } -@after { popMsg(state); } - : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type) - ; - -structType -@init { pushMsg("struct type", state); } -@after { popMsg(state); } - : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList) - ; - -mapType -@init { pushMsg("map type", state); } -@after { popMsg(state); } - : KW_MAP LESSTHAN left=type COMMA right=type GREATERTHAN - -> ^(TOK_MAP $left $right) - ; - -unionType -@init { pushMsg("uniontype type", state); } -@after { popMsg(state); } - : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) - ; - -setOperator -@init { pushMsg("set operator", state); } -@after { popMsg(state); } - : KW_UNION KW_ALL -> ^(TOK_UNIONALL) - | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT) - | KW_EXCEPT -> ^(TOK_EXCEPT) - | KW_INTERSECT -> ^(TOK_INTERSECT) - ; - -queryStatementExpression[boolean topLevel] - : - /* Would be nice to do this as a gated semantic perdicate - But the predicate gets pushed as a lookahead decision. - Calling rule doesnot know about topLevel - */ - (w=withClause {topLevel}?)? - queryStatementExpressionBody[topLevel] { - if ($w.tree != null) { - $queryStatementExpressionBody.tree.insertChild(0, $w.tree); - } - } - -> queryStatementExpressionBody - ; - -queryStatementExpressionBody[boolean topLevel] - : - fromStatement[topLevel] - | regularBody[topLevel] - ; - -withClause - : - KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+) -; - -cteStatement - : - identifier KW_AS LPAREN queryStatementExpression[false] RPAREN - -> ^(TOK_SUBQUERY queryStatementExpression identifier) -; - -fromStatement[boolean topLevel] -: (singleFromStatement -> singleFromStatement) - (u=setOperator r=singleFromStatement - -> ^($u {$fromStatement.tree} $r) - )* - -> {u != null && topLevel}? ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - {$fromStatement.tree} - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) - ) - ) - -> {$fromStatement.tree} - ; - - -singleFromStatement - : - fromClause - ( b+=body )+ -> ^(TOK_QUERY fromClause body+) - ; - -/* -The valuesClause rule below ensures that the parse tree for -"insert into table FOO values (1,2),(3,4)" looks the same as -"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look -very similar to the tree for "insert into table FOO select a,b from BAR". Since virtual table name -is implicit, it's represented as TOK_ANONYMOUS. -*/ -regularBody[boolean topLevel] - : - i=insertClause - ( - s=selectStatement[topLevel] - {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree} - | - valuesClause - -> ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause) - ) - ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))) - ) - ) - | - selectStatement[topLevel] - ; - -selectStatement[boolean topLevel] - : - ( - ( - LPAREN - s=selectClause - f=fromClause? - w=whereClause? - g=groupByClause? - h=havingClause? - o=orderByClause? - c=clusterByClause? - d=distributeByClause? - sort=sortByClause? - win=window_clause? - l=limitClause? - RPAREN - | - s=selectClause - f=fromClause? - w=whereClause? - g=groupByClause? - h=havingClause? 
- o=orderByClause? - c=clusterByClause? - d=distributeByClause? - sort=sortByClause? - win=window_clause? - l=limitClause? - ) - -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - $s $w? $g? $h? $o? $c? - $d? $sort? $win? $l?)) - ) - (set=setOpSelectStatement[$selectStatement.tree, topLevel])? - -> {set == null}? - {$selectStatement.tree} - -> {o==null && c==null && d==null && sort==null && l==null}? - {$set.tree} - -> {throwSetOpException()} - ; - -setOpSelectStatement[CommonTree t, boolean topLevel] - : - (( - u=setOperator LPAREN b=simpleSelectStatement RPAREN - | - u=setOperator b=simpleSelectStatement) - -> {$setOpSelectStatement.tree != null}? - ^($u {$setOpSelectStatement.tree} $b) - -> ^($u {$t} $b) - )+ - o=orderByClause? - c=clusterByClause? - d=distributeByClause? - sort=sortByClause? - win=window_clause? - l=limitClause? - -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}? - {$setOpSelectStatement.tree} - -> ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - {$setOpSelectStatement.tree} - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) - $o? $c? $d? $sort? $win? $l? - ) - ) - ; - -simpleSelectStatement - : - selectClause - fromClause? - whereClause? - groupByClause? - havingClause? - ((window_clause) => window_clause)? - -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - selectClause whereClause? groupByClause? havingClause? window_clause?)) - ; - -selectStatementWithCTE - : - (w=withClause)? - selectStatement[true] { - if ($w.tree != null) { - $selectStatement.tree.insertChild(0, $w.tree); - } - } - -> selectStatement - ; - -body - : - insertClause - selectClause - lateralView? - whereClause? - groupByClause? - havingClause? - orderByClause? - clusterByClause? - distributeByClause? - sortByClause? - window_clause? - limitClause? -> ^(TOK_INSERT insertClause - selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? - distributeByClause? sortByClause? window_clause? limitClause?) - | - selectClause - lateralView? - whereClause? - groupByClause? - havingClause? - orderByClause? - clusterByClause? - distributeByClause? - sortByClause? - window_clause? - limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? - distributeByClause? sortByClause? window_clause? limitClause?) - ; - -insertClause -@init { pushMsg("insert clause", state); } -@after { popMsg(state); } - : - KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?) - | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)? - -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?) - ; - -destination -@init { pushMsg("destination specification", state); } -@after { popMsg(state); } - : - (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat? - -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?) - | KW_TABLE tableOrPartition -> tableOrPartition - ; - -limitClause -@init { pushMsg("limit clause", state); } -@after { popMsg(state); } - : - KW_LIMIT num=Number -> ^(TOK_LIMIT $num) - ; - -//DELETE FROM WHERE ...; -deleteStatement -@init { pushMsg("delete statement", state); } -@after { popMsg(state); } - : - KW_DELETE KW_FROM tableName (whereClause)? 
-> ^(TOK_DELETE_FROM tableName whereClause?) - ; - -/*SET <column> = (3 + col2)*/ -columnAssignmentClause - : - tableOrColumn EQUAL^ precedencePlusExpression - ; - -/*SET col1 = 5, col2 = (4 + col4), ...*/ -setColumnsClause - : - KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) - ; - -/* - UPDATE <table>
- SET col1 = val1, col2 = val2... WHERE ... -*/ -updateStatement -@init { pushMsg("update statement", state); } -@after { popMsg(state); } - : - KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?) - ; - -/* -BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly except for addition of -"setAutoCommitStatement" which is not in the standard doc but is supported by most SQL engines. -*/ -sqlTransactionStatement -@init { pushMsg("transaction statement", state); } -@after { popMsg(state); } - : startTransactionStatement - | commitStatement - | rollbackStatement - | setAutoCommitStatement - ; - -startTransactionStatement - : - KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*) - ; - -transactionMode - : - isolationLevel - | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode) - ; - -transactionAccessMode - : - KW_READ KW_ONLY -> TOK_TXN_READ_ONLY - | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE - ; - -isolationLevel - : - KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation) - ; - -/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/ -levelOfIsolation - : - KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT - ; - -commitStatement - : - KW_COMMIT ( KW_WORK )? -> TOK_COMMIT - ; - -rollbackStatement - : - KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK - ; -setAutoCommitStatement - : - KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok) - ; -/* -END user defined transaction boundaries -*/ - -/* -Table Caching statements. - */ -cacheStatement -@init { pushMsg("cache statement", state); } -@after { popMsg(state); } - : - cacheTableStatement - | uncacheTableStatement - | clearCacheStatement - ; - -cacheTableStatement - : - KW_CACHE (lazy=KW_LAZY)? KW_TABLE identifier (KW_AS selectStatementWithCTE)? -> ^(TOK_CACHETABLE identifier $lazy? selectStatementWithCTE?) - ; - -uncacheTableStatement - : - KW_UNCACHE KW_TABLE identifier -> ^(TOK_UNCACHETABLE identifier) - ; - -clearCacheStatement - : - KW_CLEAR KW_CACHE -> ^(TOK_CLEARCACHE) - ; - diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 new file mode 100644 index 0000000000..3b9f82a80f --- /dev/null +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -0,0 +1,943 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. 
+ */ + +grammar SqlBase; + +tokens { + DELIMITER +} + +singleStatement + : statement EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +statement + : query #statementDefault + | USE db=identifier #use + | CREATE DATABASE (IF NOT EXISTS)? identifier + (COMMENT comment=STRING)? locationSpec? + (WITH DBPROPERTIES tablePropertyList)? #createDatabase + | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties + | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase + | createTableHeader ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTableUsing + | createTableHeader tableProvider + (OPTIONS tablePropertyList)? AS? query #createTableUsing + | createTableHeader ('(' columns=colTypeList ')')? + (COMMENT STRING)? + (PARTITIONED BY '(' partitionColumns=colTypeList ')')? + bucketSpec? skewSpec? + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + (AS? query)? #createTable + | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq?)? #analyze + | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable + | ALTER TABLE tableIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER TABLE tableIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER TABLE tableIdentifier bucketSpec #bucketTable + | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable + | ALTER TABLE tableIdentifier NOT SORTED #unsortTable + | ALTER TABLE tableIdentifier skewSpec #skewTable + | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable + | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable + | ALTER TABLE tableIdentifier + SET SKEWED LOCATION skewedLocationList #setTableSkewLocations + | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE tableIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER TABLE from=tableIdentifier + EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition + | ALTER TABLE tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition + | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition + | ALTER TABLE tableIdentifier partitionSpec? + SET FILEFORMAT fileFormat #setTableFileFormat + | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation + | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable + | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable + | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable + | ALTER TABLE tableIdentifier partitionSpec? + CHANGE COLUMN? oldName=identifier colType + (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn + | ALTER TABLE tableIdentifier partitionSpec? + ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns + | ALTER TABLE tableIdentifier partitionSpec? + REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? 
#replaceColumns + | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? + (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier + identifierCommentList? (COMMENT STRING)? + (PARTITIONED ON identifierList)? + (TBLPROPERTIES tablePropertyList)? AS query #createView + | ALTER VIEW tableIdentifier AS? query #alterViewQuery + | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction + | EXPLAIN explainOption* statement #explain + | SHOW TABLES ((FROM | IN) db=identifier)? + (LIKE (qualifiedName | pattern=STRING))? #showTables + | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions + | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction + | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? + tableIdentifier partitionSpec? describeColName? #describeTable + | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase + | REFRESH TABLE tableIdentifier #refreshTable + | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable + | UNCACHE TABLE identifier #uncacheTable + | CLEAR CACHE #clearCache + | ADD identifier .*? #addResource + | SET ROLE .*? #failNativeCommand + | SET .*? #setConfiguration + | kws=unsupportedHiveNativeCommands .*? #failNativeCommand + | hiveNativeCommands #executeNativeCommand + ; + +hiveNativeCommands + : createTableHeader LIKE tableIdentifier + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + | DELETE FROM tableIdentifier (WHERE booleanExpression)? + | TRUNCATE TABLE tableIdentifier partitionSpec? + (COLUMNS identifierList)? + | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier + | ALTER VIEW from=tableIdentifier AS? + SET TBLPROPERTIES tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + ADD (IF NOT EXISTS)? partitionSpecLocation+ + | ALTER VIEW from=tableIdentifier AS? + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? + | DROP VIEW (IF EXISTS)? qualifiedName + | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? + | START TRANSACTION (transactionMode (',' transactionMode)*)? + | COMMIT WORK? + | ROLLBACK WORK? + | SHOW PARTITIONS tableIdentifier partitionSpec? + | DFS .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOCK | UNLOCK | MSCK | LOAD) .*? + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? + | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +query + : ctes? queryNoWith + ; + +insertInto + : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? + | INSERT INTO TABLE? tableIdentifier partitionSpec? 
+ ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +describeColName + : identifier ('.' (identifier | STRING))* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=identifier AS? '(' queryNoWith ')' + ; + +tableProvider + : USING qualifiedName + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=STRING)? + ; + +tablePropertyKey + : looseIdentifier ('.' looseIdentifier)* + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +skewedLocation + : (constant | constantList) EQ STRING + ; + +skewedLocationList + : '(' skewedLocation (',' skewedLocation)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? + (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +queryNoWith + : insertInto? queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windows? + (LIMIT limit=expression)? + ; + +multiInsertQueryBody + : insertInto? + querySpecification + queryOrganization + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | TABLE tableIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' queryNoWith ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? + ; + +querySpecification + : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')' + | kind=MAP namedExpressionSeq + | kind=REDUCE namedExpressionSeq)) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + fromClause? + (WHERE where=booleanExpression)?) + | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) + lateralView* + (WHERE where=booleanExpression)? + aggregation? + (HAVING having=booleanExpression)? + windows?) + ; + +fromClause + : FROM relation (',' relation)* lateralView* + ; + +aggregation + : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : left=relation + ((CROSS | joinType) JOIN right=relation joinCriteria? 
+ | NATURAL joinType JOIN right=relation + ) #joinRelation + | relationPrimary #relationDefault + ; + +joinType + : INNER? + | LEFT OUTER? + | LEFT SEMI + | RIGHT OUTER? + | FULL OUTER? + ; + +joinCriteria + : ON booleanExpression + | USING '(' identifier (',' identifier)* ')' + ; + +sample + : TABLESAMPLE '(' + ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) + | (expression sampleType=ROWS) + | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) + ')' + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : identifier (',' identifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : identifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier (COMMENT STRING)? + ; + +relationPrimary + : tableIdentifier sample? (AS? identifier)? #tableName + | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery + | '(' relation ')' sample? (AS? identifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + ; + +inlineTable + : VALUES expression (',' expression)* (AS? identifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +tableIdentifier + : (db=identifier '.')? table=identifier + ; + +namedExpression + : expression (AS? (identifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +expression + : booleanExpression + ; + +booleanExpression + : predicated #booleanDefault + | NOT booleanExpression #logicalNot + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + | EXISTS '(' query ')' #exists + ; + +// workaround for: +// https://github.com/antlr/antlr4/issues/780 +// https://github.com/antlr/antlr4/issues/781 +predicated + : valueExpression predicate[$valueExpression.ctx]? + ; + +predicate[ParserRuleContext value] + : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between + | NOT? IN '(' expression (',' expression)* ')' #inList + | NOT? IN '(' query ')' #inSubquery + | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like + | IS NOT? NULL #nullPredicate + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : constant #constantDefault + | ASTERISK #star + | qualifiedName '.' 
ASTERISK #star + | '(' expression (',' expression)+ ')' #rowConstructor + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | '(' query ')' #subqueryExpression + | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL intervalField* + ; + +intervalField + : value=intervalValue unit=identifier (TO to=identifier)? + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE) + | STRING + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : identifier ':'? dataType (COMMENT STRING)? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windows + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : identifier AS windowSpec + ; + +windowSpec + : name=identifier #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? + ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + + +explainOption + : LOGICAL | FORMATTED | EXTENDED + ; + +transactionMode + : ISOLATION LEVEL SNAPSHOT #isolationLevel + | READ accessMode=(ONLY | WRITE) #transactionAccessMode + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). 
+looseIdentifier + : identifier + | FROM + | TO + | TABLE + | WITH + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : DECIMAL_VALUE #decimalLiteral + | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral + | INTEGER_VALUE #integerLiteral + | BIGINT_LITERAL #bigIntLiteral + | SMALLINT_LITERAL #smallIntLiteral + | TINYINT_LITERAL #tinyIntLiteral + | DOUBLE_LITERAL #doubleLiteral + ; + +nonReserved + : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS + | ADD + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT + | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER + | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED + | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS + | GROUPING | CUBE | ROLLUP + | EXPLAIN | FORMAT | LOGICAL | FORMATTED + | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF + | SET + | VIEW | REPLACE + | IF + | NO | DATA + | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL + | SNAPSHOT | READ | WRITE | ONLY + | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION + | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST + | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT + | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE + | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEXES | LOCKS | OPTION + ; + +SELECT: 'SELECT'; +FROM: 'FROM'; +ADD: 'ADD'; +AS: 'AS'; +ALL: 'ALL'; +DISTINCT: 'DISTINCT'; +WHERE: 'WHERE'; +GROUP: 'GROUP'; +BY: 'BY'; +GROUPING: 'GROUPING'; +SETS: 'SETS'; +CUBE: 'CUBE'; +ROLLUP: 'ROLLUP'; +ORDER: 'ORDER'; +HAVING: 'HAVING'; +LIMIT: 'LIMIT'; +AT: 'AT'; +OR: 'OR'; +AND: 'AND'; +IN: 'IN'; +NOT: 'NOT' | '!'; +NO: 'NO'; +EXISTS: 'EXISTS'; +BETWEEN: 'BETWEEN'; +LIKE: 'LIKE'; +RLIKE: 'RLIKE' | 'REGEXP'; +IS: 'IS'; +NULL: 'NULL'; +TRUE: 'TRUE'; +FALSE: 'FALSE'; +NULLS: 'NULLS'; +ASC: 'ASC'; +DESC: 'DESC'; +FOR: 'FOR'; +INTERVAL: 'INTERVAL'; +CASE: 'CASE'; +WHEN: 'WHEN'; +THEN: 'THEN'; +ELSE: 'ELSE'; +END: 'END'; +JOIN: 'JOIN'; +CROSS: 'CROSS'; +OUTER: 'OUTER'; +INNER: 'INNER'; +LEFT: 'LEFT'; +SEMI: 'SEMI'; +RIGHT: 'RIGHT'; +FULL: 'FULL'; +NATURAL: 'NATURAL'; +ON: 'ON'; +LATERAL: 'LATERAL'; +WINDOW: 'WINDOW'; +OVER: 'OVER'; +PARTITION: 'PARTITION'; +RANGE: 'RANGE'; +ROWS: 'ROWS'; +UNBOUNDED: 'UNBOUNDED'; +PRECEDING: 'PRECEDING'; +FOLLOWING: 'FOLLOWING'; +CURRENT: 'CURRENT'; +ROW: 'ROW'; +WITH: 'WITH'; +VALUES: 'VALUES'; +CREATE: 'CREATE'; +TABLE: 'TABLE'; +VIEW: 'VIEW'; +REPLACE: 'REPLACE'; +INSERT: 'INSERT'; +DELETE: 'DELETE'; +INTO: 'INTO'; +DESCRIBE: 'DESCRIBE'; +EXPLAIN: 'EXPLAIN'; +FORMAT: 'FORMAT'; +LOGICAL: 'LOGICAL'; +CAST: 'CAST'; +SHOW: 'SHOW'; +TABLES: 'TABLES'; +COLUMNS: 'COLUMNS'; +COLUMN: 'COLUMN'; +USE: 'USE'; +PARTITIONS: 'PARTITIONS'; +FUNCTIONS: 'FUNCTIONS'; +DROP: 'DROP'; +UNION: 'UNION'; +EXCEPT: 'EXCEPT'; +INTERSECT: 'INTERSECT'; +TO: 'TO'; +TABLESAMPLE: 'TABLESAMPLE'; +STRATIFY: 'STRATIFY'; +ALTER: 'ALTER'; +RENAME: 'RENAME'; +ARRAY: 'ARRAY'; +MAP: 'MAP'; +STRUCT: 'STRUCT'; 
+COMMENT: 'COMMENT'; +SET: 'SET'; +DATA: 'DATA'; +START: 'START'; +TRANSACTION: 'TRANSACTION'; +COMMIT: 'COMMIT'; +ROLLBACK: 'ROLLBACK'; +WORK: 'WORK'; +ISOLATION: 'ISOLATION'; +LEVEL: 'LEVEL'; +SNAPSHOT: 'SNAPSHOT'; +READ: 'READ'; +WRITE: 'WRITE'; +ONLY: 'ONLY'; + +IF: 'IF'; + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<='; +GT : '>'; +GTE : '>='; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +DIV: 'DIV'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +HAT: '^'; + +PERCENTLIT: 'PERCENT'; +BUCKET: 'BUCKET'; +OUT: 'OUT'; +OF: 'OF'; + +SORT: 'SORT'; +CLUSTER: 'CLUSTER'; +DISTRIBUTE: 'DISTRIBUTE'; +OVERWRITE: 'OVERWRITE'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; +USING: 'USING'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +DELIMITED: 'DELIMITED'; +FIELDS: 'FIELDS'; +TERMINATED: 'TERMINATED'; +COLLECTION: 'COLLECTION'; +ITEMS: 'ITEMS'; +KEYS: 'KEYS'; +ESCAPED: 'ESCAPED'; +LINES: 'LINES'; +SEPARATED: 'SEPARATED'; +FUNCTION: 'FUNCTION'; +EXTENDED: 'EXTENDED'; +REFRESH: 'REFRESH'; +CLEAR: 'CLEAR'; +CACHE: 'CACHE'; +UNCACHE: 'UNCACHE'; +LAZY: 'LAZY'; +FORMATTED: 'FORMATTED'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +OPTIONS: 'OPTIONS'; +UNSET: 'UNSET'; +TBLPROPERTIES: 'TBLPROPERTIES'; +DBPROPERTIES: 'DBPROPERTIES'; +BUCKETS: 'BUCKETS'; +SKEWED: 'SKEWED'; +STORED: 'STORED'; +DIRECTORIES: 'DIRECTORIES'; +LOCATION: 'LOCATION'; +EXCHANGE: 'EXCHANGE'; +ARCHIVE: 'ARCHIVE'; +UNARCHIVE: 'UNARCHIVE'; +FILEFORMAT: 'FILEFORMAT'; +TOUCH: 'TOUCH'; +COMPACT: 'COMPACT'; +CONCATENATE: 'CONCATENATE'; +CHANGE: 'CHANGE'; +FIRST: 'FIRST'; +AFTER: 'AFTER'; +CASCADE: 'CASCADE'; +RESTRICT: 'RESTRICT'; +CLUSTERED: 'CLUSTERED'; +SORTED: 'SORTED'; +PURGE: 'PURGE'; +INPUTFORMAT: 'INPUTFORMAT'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +INPUTDRIVER: 'INPUTDRIVER'; +OUTPUTDRIVER: 'OUTPUTDRIVER'; +DATABASE: 'DATABASE' | 'SCHEMA'; +DFS: 'DFS'; +TRUNCATE: 'TRUNCATE'; +METADATA: 'METADATA'; +REPLICATION: 'REPLICATION'; +ANALYZE: 'ANALYZE'; +COMPUTE: 'COMPUTE'; +STATISTICS: 'STATISTICS'; +PARTITIONED: 'PARTITIONED'; +EXTERNAL: 'EXTERNAL'; +DEFINED: 'DEFINED'; +REVOKE: 'REVOKE'; +GRANT: 'GRANT'; +LOCK: 'LOCK'; +UNLOCK: 'UNLOCK'; +MSCK: 'MSCK'; +EXPORT: 'EXPORT'; +IMPORT: 'IMPORT'; +LOAD: 'LOAD'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +COMPACTIONS: 'COMPACTIONS'; +PRINCIPALS: 'PRINCIPALS'; +TRANSACTIONS: 'TRANSACTIONS'; +INDEXES: 'INDEXES'; +LOCKS: 'LOCKS'; +OPTION: 'OPTION'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +SCIENTIFIC_DECIMAL_VALUE + : DIGIT+ ('.' DIGIT*)? EXPONENT + | '.' DIGIT+ EXPONENT + ; + +DOUBLE_LITERAL + : + (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' .*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . 
+ ; diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 deleted file mode 100644 index 4e77b6db25..0000000000 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 +++ /dev/null @@ -1,941 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. - */ - -grammar SqlBase; - -tokens { - DELIMITER -} - -singleStatement - : statement EOF - ; - -singleExpression - : namedExpression EOF - ; - -singleTableIdentifier - : tableIdentifier EOF - ; - -singleDataType - : dataType EOF - ; - -statement - : query #statementDefault - | USE db=identifier #use - | CREATE DATABASE (IF NOT EXISTS)? identifier - (COMMENT comment=STRING)? locationSpec? - (WITH DBPROPERTIES tablePropertyList)? #createDatabase - | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties - | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase - | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS tablePropertyList)? #createTableUsing - | createTableHeader tableProvider - (OPTIONS tablePropertyList)? AS? query #createTableUsing - | createTableHeader ('(' colTypeList ')')? (COMMENT STRING)? - (PARTITIONED BY identifierList)? bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? - (AS? query)? #createTable - | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS - (identifier | FOR COLUMNS identifierSeq?) #analyze - | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable - | ALTER TABLE tableIdentifier - SET TBLPROPERTIES tablePropertyList #setTableProperties - | ALTER TABLE tableIdentifier - UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties - | ALTER TABLE tableIdentifier (partitionSpec)? - SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe - | ALTER TABLE tableIdentifier (partitionSpec)? - SET SERDEPROPERTIES tablePropertyList #setTableSerDe - | ALTER TABLE tableIdentifier bucketSpec #bucketTable - | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable - | ALTER TABLE tableIdentifier NOT SORTED #unsortTable - | ALTER TABLE tableIdentifier skewSpec #skewTable - | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable - | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable - | ALTER TABLE tableIdentifier - SET SKEWED LOCATION skewedLocationList #setTableSkewLocations - | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? - partitionSpecLocation+ #addTablePartition - | ALTER TABLE tableIdentifier - from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition - | ALTER TABLE from=tableIdentifier - EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition - | ALTER TABLE tableIdentifier - DROP (IF EXISTS)? 
partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions - | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition - | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition - | ALTER TABLE tableIdentifier partitionSpec? - SET FILEFORMAT fileFormat #setTableFileFormat - | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation - | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable - | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable - | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable - | ALTER TABLE tableIdentifier partitionSpec? - CHANGE COLUMN? oldName=identifier colType - (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn - | ALTER TABLE tableIdentifier partitionSpec? - ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns - | ALTER TABLE tableIdentifier partitionSpec? - REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns - | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? - (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable - | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier - identifierCommentList? (COMMENT STRING)? - (PARTITIONED ON identifierList)? - (TBLPROPERTIES tablePropertyList)? AS query #createView - | ALTER VIEW tableIdentifier AS? query #alterViewQuery - | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING - (USING resource (',' resource)*)? #createFunction - | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction - | EXPLAIN explainOption* statement #explain - | SHOW TABLES ((FROM | IN) db=identifier)? - (LIKE (qualifiedName | pattern=STRING))? #showTables - | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions - | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction - | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? - tableIdentifier partitionSpec? describeColName? #describeTable - | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase - | REFRESH TABLE tableIdentifier #refreshTable - | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable - | UNCACHE TABLE identifier #uncacheTable - | CLEAR CACHE #clearCache - | ADD identifier .*? #addResource - | SET ROLE .*? #failNativeCommand - | SET .*? #setConfiguration - | kws=unsupportedHiveNativeCommands .*? #failNativeCommand - | hiveNativeCommands #executeNativeCommand - ; - -hiveNativeCommands - : createTableHeader LIKE tableIdentifier - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? - | DELETE FROM tableIdentifier (WHERE booleanExpression)? - | TRUNCATE TABLE tableIdentifier partitionSpec? - (COLUMNS identifierList)? - | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier - | ALTER VIEW from=tableIdentifier AS? - SET TBLPROPERTIES tablePropertyList - | ALTER VIEW from=tableIdentifier AS? - UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList - | ALTER VIEW from=tableIdentifier AS? - ADD (IF NOT EXISTS)? partitionSpecLocation+ - | ALTER VIEW from=tableIdentifier AS? - DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? - | DROP VIEW (IF EXISTS)? qualifiedName - | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? - | START TRANSACTION (transactionMode (',' transactionMode)*)? - | COMMIT WORK? - | ROLLBACK WORK? - | SHOW PARTITIONS tableIdentifier partitionSpec? - | DFS .*? - | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOCK | UNLOCK | MSCK | LOAD) .*? 
- ; - -unsupportedHiveNativeCommands - : kw1=CREATE kw2=ROLE - | kw1=DROP kw2=ROLE - | kw1=GRANT kw2=ROLE? - | kw1=REVOKE kw2=ROLE? - | kw1=SHOW kw2=GRANT - | kw1=SHOW kw2=ROLE kw3=GRANT? - | kw1=SHOW kw2=PRINCIPALS - | kw1=SHOW kw2=ROLES - | kw1=SHOW kw2=CURRENT kw3=ROLES - | kw1=EXPORT kw2=TABLE - | kw1=IMPORT kw2=TABLE - | kw1=SHOW kw2=COMPACTIONS - | kw1=SHOW kw2=CREATE kw3=TABLE - | kw1=SHOW kw2=TRANSACTIONS - | kw1=SHOW kw2=INDEXES - | kw1=SHOW kw2=LOCKS - ; - -createTableHeader - : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier - ; - -bucketSpec - : CLUSTERED BY identifierList - (SORTED BY orderedIdentifierList)? - INTO INTEGER_VALUE BUCKETS - ; - -skewSpec - : SKEWED BY identifierList - ON (constantList | nestedConstantList) - (STORED AS DIRECTORIES)? - ; - -locationSpec - : LOCATION STRING - ; - -query - : ctes? queryNoWith - ; - -insertInto - : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? - | INSERT INTO TABLE? tableIdentifier partitionSpec? - ; - -partitionSpecLocation - : partitionSpec locationSpec? - ; - -partitionSpec - : PARTITION '(' partitionVal (',' partitionVal)* ')' - ; - -partitionVal - : identifier (EQ constant)? - ; - -describeColName - : identifier ('.' (identifier | STRING))* - ; - -ctes - : WITH namedQuery (',' namedQuery)* - ; - -namedQuery - : name=identifier AS? '(' queryNoWith ')' - ; - -tableProvider - : USING qualifiedName - ; - -tablePropertyList - : '(' tableProperty (',' tableProperty)* ')' - ; - -tableProperty - : key=tablePropertyKey (EQ? value=STRING)? - ; - -tablePropertyKey - : looseIdentifier ('.' looseIdentifier)* - | STRING - ; - -constantList - : '(' constant (',' constant)* ')' - ; - -nestedConstantList - : '(' constantList (',' constantList)* ')' - ; - -skewedLocation - : (constant | constantList) EQ STRING - ; - -skewedLocationList - : '(' skewedLocation (',' skewedLocation)* ')' - ; - -createFileFormat - : STORED AS fileFormat - | STORED BY storageHandler - ; - -fileFormat - : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? - (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat - | identifier #genericFileFormat - ; - -storageHandler - : STRING (WITH SERDEPROPERTIES tablePropertyList)? - ; - -resource - : identifier STRING - ; - -queryNoWith - : insertInto? queryTerm queryOrganization #singleInsertQuery - | fromClause multiInsertQueryBody+ #multiInsertQuery - ; - -queryOrganization - : (ORDER BY order+=sortItem (',' order+=sortItem)*)? - (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? - (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? - (SORT BY sort+=sortItem (',' sort+=sortItem)*)? - windows? - (LIMIT limit=expression)? - ; - -multiInsertQueryBody - : insertInto? - querySpecification - queryOrganization - ; - -queryTerm - : queryPrimary #queryTermDefault - | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation - ; - -queryPrimary - : querySpecification #queryPrimaryDefault - | TABLE tableIdentifier #table - | inlineTable #inlineTableDefault1 - | '(' queryNoWith ')' #subquery - ; - -sortItem - : expression ordering=(ASC | DESC)? - ; - -querySpecification - : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')' - | kind=MAP namedExpressionSeq - | kind=REDUCE namedExpressionSeq)) - inRowFormat=rowFormat? - (RECORDWRITER recordWriter=STRING)? - USING script=STRING - (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? 
- outRowFormat=rowFormat? - (RECORDREADER recordReader=STRING)? - fromClause? - (WHERE where=booleanExpression)?) - | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? - | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) - lateralView* - (WHERE where=booleanExpression)? - aggregation? - (HAVING having=booleanExpression)? - windows?) - ; - -fromClause - : FROM relation (',' relation)* lateralView* - ; - -aggregation - : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( - WITH kind=ROLLUP - | WITH kind=CUBE - | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? - ; - -groupingSet - : '(' (expression (',' expression)*)? ')' - | expression - ; - -lateralView - : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? - ; - -setQuantifier - : DISTINCT - | ALL - ; - -relation - : left=relation - ((CROSS | joinType) JOIN right=relation joinCriteria? - | NATURAL joinType JOIN right=relation - ) #joinRelation - | relationPrimary #relationDefault - ; - -joinType - : INNER? - | LEFT OUTER? - | LEFT SEMI - | RIGHT OUTER? - | FULL OUTER? - ; - -joinCriteria - : ON booleanExpression - | USING '(' identifier (',' identifier)* ')' - ; - -sample - : TABLESAMPLE '(' - ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) - | (expression sampleType=ROWS) - | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) - ')' - ; - -identifierList - : '(' identifierSeq ')' - ; - -identifierSeq - : identifier (',' identifier)* - ; - -orderedIdentifierList - : '(' orderedIdentifier (',' orderedIdentifier)* ')' - ; - -orderedIdentifier - : identifier ordering=(ASC | DESC)? - ; - -identifierCommentList - : '(' identifierComment (',' identifierComment)* ')' - ; - -identifierComment - : identifier (COMMENT STRING)? - ; - -relationPrimary - : tableIdentifier sample? (AS? identifier)? #tableName - | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery - | '(' relation ')' sample? (AS? identifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - ; - -inlineTable - : VALUES expression (',' expression)* (AS? identifier identifierList?)? - ; - -rowFormat - : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde - | ROW FORMAT DELIMITED - (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? - (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? - (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? - (LINES TERMINATED BY linesSeparatedBy=STRING)? - (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited - ; - -tableIdentifier - : (db=identifier '.')? table=identifier - ; - -namedExpression - : expression (AS? (identifier | identifierList))? - ; - -namedExpressionSeq - : namedExpression (',' namedExpression)* - ; - -expression - : booleanExpression - ; - -booleanExpression - : predicated #booleanDefault - | NOT booleanExpression #logicalNot - | left=booleanExpression operator=AND right=booleanExpression #logicalBinary - | left=booleanExpression operator=OR right=booleanExpression #logicalBinary - | EXISTS '(' query ')' #exists - ; - -// workaround for: -// https://github.com/antlr/antlr4/issues/780 -// https://github.com/antlr/antlr4/issues/781 -predicated - : valueExpression predicate[$valueExpression.ctx]? - ; - -predicate[ParserRuleContext value] - : NOT? 
BETWEEN lower=valueExpression AND upper=valueExpression #between - | NOT? IN '(' expression (',' expression)* ')' #inList - | NOT? IN '(' query ')' #inSubquery - | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like - | IS NOT? NULL #nullPredicate - ; - -valueExpression - : primaryExpression #valueExpressionDefault - | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary - | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary - | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary - | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary - | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary - | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary - | left=valueExpression comparisonOperator right=valueExpression #comparison - ; - -primaryExpression - : constant #constantDefault - | ASTERISK #star - | qualifiedName '.' ASTERISK #star - | '(' expression (',' expression)+ ')' #rowConstructor - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall - | '(' query ')' #subqueryExpression - | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase - | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase - | CAST '(' expression AS dataType ')' #cast - | value=primaryExpression '[' index=valueExpression ']' #subscript - | identifier #columnReference - | base=primaryExpression '.' fieldName=identifier #dereference - | '(' expression ')' #parenthesizedExpression - ; - -constant - : NULL #nullLiteral - | interval #intervalLiteral - | identifier STRING #typeConstructor - | number #numericLiteral - | booleanValue #booleanLiteral - | STRING+ #stringLiteral - ; - -comparisonOperator - : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ - ; - -booleanValue - : TRUE | FALSE - ; - -interval - : INTERVAL intervalField* - ; - -intervalField - : value=intervalValue unit=identifier (TO to=identifier)? - ; - -intervalValue - : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE) - | STRING - ; - -dataType - : complex=ARRAY '<' dataType '>' #complexDataType - | complex=MAP '<' dataType ',' dataType '>' #complexDataType - | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType - | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType - ; - -colTypeList - : colType (',' colType)* - ; - -colType - : identifier ':'? dataType (COMMENT STRING)? - ; - -whenClause - : WHEN condition=expression THEN result=expression - ; - -windows - : WINDOW namedWindow (',' namedWindow)* - ; - -namedWindow - : identifier AS windowSpec - ; - -windowSpec - : name=identifier #windowRef - | '(' - ( CLUSTER BY partition+=expression (',' partition+=expression)* - | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? - ((ORDER | SORT) BY sortItem (',' sortItem)*)?) - windowFrame? 
- ')' #windowDef - ; - -windowFrame - : frameType=RANGE start=frameBound - | frameType=ROWS start=frameBound - | frameType=RANGE BETWEEN start=frameBound AND end=frameBound - | frameType=ROWS BETWEEN start=frameBound AND end=frameBound - ; - -frameBound - : UNBOUNDED boundType=(PRECEDING | FOLLOWING) - | boundType=CURRENT ROW - | expression boundType=(PRECEDING | FOLLOWING) - ; - - -explainOption - : LOGICAL | FORMATTED | EXTENDED - ; - -transactionMode - : ISOLATION LEVEL SNAPSHOT #isolationLevel - | READ accessMode=(ONLY | WRITE) #transactionAccessMode - ; - -qualifiedName - : identifier ('.' identifier)* - ; - -// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). -looseIdentifier - : identifier - | FROM - | TO - | TABLE - | WITH - ; - -identifier - : IDENTIFIER #unquotedIdentifier - | quotedIdentifier #quotedIdentifierAlternative - | nonReserved #unquotedIdentifier - ; - -quotedIdentifier - : BACKQUOTED_IDENTIFIER - ; - -number - : DECIMAL_VALUE #decimalLiteral - | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral - | INTEGER_VALUE #integerLiteral - | BIGINT_LITERAL #bigIntLiteral - | SMALLINT_LITERAL #smallIntLiteral - | TINYINT_LITERAL #tinyIntLiteral - | DOUBLE_LITERAL #doubleLiteral - ; - -nonReserved - : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS - | ADD - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER - | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED - | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS - | GROUPING | CUBE | ROLLUP - | EXPLAIN | FORMAT | LOGICAL | FORMATTED - | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF - | SET - | VIEW | REPLACE - | IF - | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL - | SNAPSHOT | READ | WRITE | ONLY - | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION - | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST - | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT - | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE - | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE - | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEXES | LOCKS | OPTION - ; - -SELECT: 'SELECT'; -FROM: 'FROM'; -ADD: 'ADD'; -AS: 'AS'; -ALL: 'ALL'; -DISTINCT: 'DISTINCT'; -WHERE: 'WHERE'; -GROUP: 'GROUP'; -BY: 'BY'; -GROUPING: 'GROUPING'; -SETS: 'SETS'; -CUBE: 'CUBE'; -ROLLUP: 'ROLLUP'; -ORDER: 'ORDER'; -HAVING: 'HAVING'; -LIMIT: 'LIMIT'; -AT: 'AT'; -OR: 'OR'; -AND: 'AND'; -IN: 'IN'; -NOT: 'NOT' | '!'; -NO: 'NO'; -EXISTS: 'EXISTS'; -BETWEEN: 'BETWEEN'; -LIKE: 'LIKE'; -RLIKE: 'RLIKE' | 'REGEXP'; -IS: 'IS'; -NULL: 'NULL'; -TRUE: 'TRUE'; -FALSE: 'FALSE'; -NULLS: 'NULLS'; -ASC: 'ASC'; -DESC: 'DESC'; -FOR: 'FOR'; -INTERVAL: 'INTERVAL'; -CASE: 'CASE'; -WHEN: 'WHEN'; -THEN: 'THEN'; -ELSE: 'ELSE'; -END: 'END'; -JOIN: 'JOIN'; -CROSS: 'CROSS'; -OUTER: 'OUTER'; -INNER: 'INNER'; -LEFT: 'LEFT'; -SEMI: 'SEMI'; -RIGHT: 'RIGHT'; -FULL: 'FULL'; -NATURAL: 'NATURAL'; -ON: 'ON'; -LATERAL: 'LATERAL'; -WINDOW: 'WINDOW'; -OVER: 'OVER'; -PARTITION: 'PARTITION'; -RANGE: 'RANGE'; -ROWS: 'ROWS'; 
-UNBOUNDED: 'UNBOUNDED'; -PRECEDING: 'PRECEDING'; -FOLLOWING: 'FOLLOWING'; -CURRENT: 'CURRENT'; -ROW: 'ROW'; -WITH: 'WITH'; -VALUES: 'VALUES'; -CREATE: 'CREATE'; -TABLE: 'TABLE'; -VIEW: 'VIEW'; -REPLACE: 'REPLACE'; -INSERT: 'INSERT'; -DELETE: 'DELETE'; -INTO: 'INTO'; -DESCRIBE: 'DESCRIBE'; -EXPLAIN: 'EXPLAIN'; -FORMAT: 'FORMAT'; -LOGICAL: 'LOGICAL'; -CAST: 'CAST'; -SHOW: 'SHOW'; -TABLES: 'TABLES'; -COLUMNS: 'COLUMNS'; -COLUMN: 'COLUMN'; -USE: 'USE'; -PARTITIONS: 'PARTITIONS'; -FUNCTIONS: 'FUNCTIONS'; -DROP: 'DROP'; -UNION: 'UNION'; -EXCEPT: 'EXCEPT'; -INTERSECT: 'INTERSECT'; -TO: 'TO'; -TABLESAMPLE: 'TABLESAMPLE'; -STRATIFY: 'STRATIFY'; -ALTER: 'ALTER'; -RENAME: 'RENAME'; -ARRAY: 'ARRAY'; -MAP: 'MAP'; -STRUCT: 'STRUCT'; -COMMENT: 'COMMENT'; -SET: 'SET'; -DATA: 'DATA'; -START: 'START'; -TRANSACTION: 'TRANSACTION'; -COMMIT: 'COMMIT'; -ROLLBACK: 'ROLLBACK'; -WORK: 'WORK'; -ISOLATION: 'ISOLATION'; -LEVEL: 'LEVEL'; -SNAPSHOT: 'SNAPSHOT'; -READ: 'READ'; -WRITE: 'WRITE'; -ONLY: 'ONLY'; - -IF: 'IF'; - -EQ : '=' | '=='; -NSEQ: '<=>'; -NEQ : '<>'; -NEQJ: '!='; -LT : '<'; -LTE : '<='; -GT : '>'; -GTE : '>='; - -PLUS: '+'; -MINUS: '-'; -ASTERISK: '*'; -SLASH: '/'; -PERCENT: '%'; -DIV: 'DIV'; -TILDE: '~'; -AMPERSAND: '&'; -PIPE: '|'; -HAT: '^'; - -PERCENTLIT: 'PERCENT'; -BUCKET: 'BUCKET'; -OUT: 'OUT'; -OF: 'OF'; - -SORT: 'SORT'; -CLUSTER: 'CLUSTER'; -DISTRIBUTE: 'DISTRIBUTE'; -OVERWRITE: 'OVERWRITE'; -TRANSFORM: 'TRANSFORM'; -REDUCE: 'REDUCE'; -USING: 'USING'; -SERDE: 'SERDE'; -SERDEPROPERTIES: 'SERDEPROPERTIES'; -RECORDREADER: 'RECORDREADER'; -RECORDWRITER: 'RECORDWRITER'; -DELIMITED: 'DELIMITED'; -FIELDS: 'FIELDS'; -TERMINATED: 'TERMINATED'; -COLLECTION: 'COLLECTION'; -ITEMS: 'ITEMS'; -KEYS: 'KEYS'; -ESCAPED: 'ESCAPED'; -LINES: 'LINES'; -SEPARATED: 'SEPARATED'; -FUNCTION: 'FUNCTION'; -EXTENDED: 'EXTENDED'; -REFRESH: 'REFRESH'; -CLEAR: 'CLEAR'; -CACHE: 'CACHE'; -UNCACHE: 'UNCACHE'; -LAZY: 'LAZY'; -FORMATTED: 'FORMATTED'; -TEMPORARY: 'TEMPORARY' | 'TEMP'; -OPTIONS: 'OPTIONS'; -UNSET: 'UNSET'; -TBLPROPERTIES: 'TBLPROPERTIES'; -DBPROPERTIES: 'DBPROPERTIES'; -BUCKETS: 'BUCKETS'; -SKEWED: 'SKEWED'; -STORED: 'STORED'; -DIRECTORIES: 'DIRECTORIES'; -LOCATION: 'LOCATION'; -EXCHANGE: 'EXCHANGE'; -ARCHIVE: 'ARCHIVE'; -UNARCHIVE: 'UNARCHIVE'; -FILEFORMAT: 'FILEFORMAT'; -TOUCH: 'TOUCH'; -COMPACT: 'COMPACT'; -CONCATENATE: 'CONCATENATE'; -CHANGE: 'CHANGE'; -FIRST: 'FIRST'; -AFTER: 'AFTER'; -CASCADE: 'CASCADE'; -RESTRICT: 'RESTRICT'; -CLUSTERED: 'CLUSTERED'; -SORTED: 'SORTED'; -PURGE: 'PURGE'; -INPUTFORMAT: 'INPUTFORMAT'; -OUTPUTFORMAT: 'OUTPUTFORMAT'; -INPUTDRIVER: 'INPUTDRIVER'; -OUTPUTDRIVER: 'OUTPUTDRIVER'; -DATABASE: 'DATABASE' | 'SCHEMA'; -DFS: 'DFS'; -TRUNCATE: 'TRUNCATE'; -METADATA: 'METADATA'; -REPLICATION: 'REPLICATION'; -ANALYZE: 'ANALYZE'; -COMPUTE: 'COMPUTE'; -STATISTICS: 'STATISTICS'; -PARTITIONED: 'PARTITIONED'; -EXTERNAL: 'EXTERNAL'; -DEFINED: 'DEFINED'; -REVOKE: 'REVOKE'; -GRANT: 'GRANT'; -LOCK: 'LOCK'; -UNLOCK: 'UNLOCK'; -MSCK: 'MSCK'; -EXPORT: 'EXPORT'; -IMPORT: 'IMPORT'; -LOAD: 'LOAD'; -ROLE: 'ROLE'; -ROLES: 'ROLES'; -COMPACTIONS: 'COMPACTIONS'; -PRINCIPALS: 'PRINCIPALS'; -TRANSACTIONS: 'TRANSACTIONS'; -INDEXES: 'INDEXES'; -LOCKS: 'LOCKS'; -OPTION: 'OPTION'; - -STRING - : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' - | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' - ; - -BIGINT_LITERAL - : DIGIT+ 'L' - ; - -SMALLINT_LITERAL - : DIGIT+ 'S' - ; - -TINYINT_LITERAL - : DIGIT+ 'Y' - ; - -INTEGER_VALUE - : DIGIT+ - ; - -DECIMAL_VALUE - : DIGIT+ '.' DIGIT* - | '.' 
DIGIT+ - ; - -SCIENTIFIC_DECIMAL_VALUE - : DIGIT+ ('.' DIGIT*)? EXPONENT - | '.' DIGIT+ EXPONENT - ; - -DOUBLE_LITERAL - : - (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' - ; - -IDENTIFIER - : (LETTER | DIGIT | '_')+ - ; - -BACKQUOTED_IDENTIFIER - : '`' ( ~'`' | '``' )* '`' - ; - -fragment EXPONENT - : 'E' [+-]? DIGIT+ - ; - -fragment DIGIT - : [0-9] - ; - -fragment LETTER - : [A-Z] - ; - -SIMPLE_COMMENT - : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) - ; - -BRACKETED_COMMENT - : '/*' .*? '*/' -> channel(HIDDEN) - ; - -WS - : [ \r\n\t]+ -> channel(HIDDEN) - ; - -// Catch-all for anything we can't recognize. -// We use this to be able to ignore and recover all the text -// when splitting statements with DelimiterLexer -UNRECOGNIZED - : . - ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala deleted file mode 100644 index 28f7b10ed6..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser - -import org.antlr.runtime.{Token, TokenRewriteStream} - -import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} - -case class ASTNode( - token: Token, - startIndex: Int, - stopIndex: Int, - children: List[ASTNode], - stream: TokenRewriteStream) extends TreeNode[ASTNode] { - /** Cache the number of children. */ - val numChildren: Int = children.size - - /** tuple used in pattern matching. */ - val pattern: Some[(String, List[ASTNode])] = Some((token.getText, children)) - - /** Line in which the ASTNode starts. */ - lazy val line: Int = { - val line = token.getLine - if (line == 0) { - if (children.nonEmpty) children.head.line - else 0 - } else { - line - } - } - - /** Position of the Character at which ASTNode starts. */ - lazy val positionInLine: Int = { - val line = token.getCharPositionInLine - if (line == -1) { - if (children.nonEmpty) children.head.positionInLine - else 0 - } else { - line - } - } - - /** Origin of the ASTNode. */ - override val origin: Origin = Origin(Some(line), Some(positionInLine)) - - /** Source text. */ - lazy val source: String = stream.toOriginalString(startIndex, stopIndex) - - /** Get the source text that remains after this token. */ - lazy val remainder: String = { - stream.fill() - stream.toOriginalString(stopIndex + 1, stream.size() - 1).trim() - } - - def text: String = token.getText - - def tokenType: Int = token.getType - - /** - * Checks if this node is equal to another node. - * - * Right now this function only checks the name, type, text and children of the node - * for equality. 
- */ - def treeEquals(other: ASTNode): Boolean = { - def check(f: ASTNode => Any): Boolean = { - val l = f(this) - val r = f(other) - (l == null && r == null) || l.equals(r) - } - if (other == null) { - false - } else if (!check(_.token.getType) - || !check(_.token.getText) - || !check(_.numChildren)) { - false - } else { - children.zip(other.children).forall { - case (l, r) => l treeEquals r - } - } - } - - override def simpleString: String = s"$text $line, $startIndex, $stopIndex, $positionInLine " -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala deleted file mode 100644 index 7b456a6de3..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser - -import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers -import scala.util.parsing.input.CharArrayReader.EofCh - -import org.apache.spark.sql.catalyst.plans.logical._ - -private[sql] abstract class AbstractSparkSQLParser - extends StandardTokenParsers with PackratParsers with ParserInterface { - - def parsePlan(input: String): LogicalPlan = synchronized { - // Initialize the Keywords. - initLexical - phrase(start)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) - } - } - /* One time initialization of lexical.This avoid reinitialization of lexical in parse method */ - protected lazy val initLexical: Unit = lexical.initialize(reservedWords) - - protected case class Keyword(str: String) { - def normalize: String = lexical.normalizeKeyword(str) - def parser: Parser[String] = normalize - } - - protected implicit def asParser(k: Keyword): Parser[String] = k.parser - - // By default, use Reflection to find the reserved words defined in the sub class. - // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this - // method during the parent class instantiation, because the sub class instance - // isn't created yet. - protected lazy val reservedWords: Seq[String] = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].normalize) - - // Set the keywords as empty by default, will change that later. 
- override val lexical = new SqlLexical - - protected def start: Parser[LogicalPlan] - - // Returns the whole input string - protected lazy val wholeInput: Parser[String] = new Parser[String] { - def apply(in: Input): ParseResult[String] = - Success(in.source.toString, in.drop(in.source.length())) - } - - // Returns the rest of the input string that are not parsed yet - protected lazy val restInput: Parser[String] = new Parser[String] { - def apply(in: Input): ParseResult[String] = - Success( - in.source.subSequence(in.offset, in.source.length()).toString, - in.drop(in.source.length())) - } -} - -class SqlLexical extends StdLexical { - case class DecimalLit(chars: String) extends Token { - override def toString: String = chars - } - - /* This is a work around to support the lazy setting */ - def initialize(keywords: Seq[String]): Unit = { - reserved.clear() - reserved ++= keywords - } - - /* Normal the keyword string */ - def normalizeKeyword(str: String): String = str.toLowerCase - - delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" - ) - - protected override def processIdent(name: String) = { - val token = normalizeKeyword(name) - if (reserved contains token) Keyword(token) else Identifier(name) - } - - override lazy val token: Parser[Token] = - ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } - | '.' ~> (rep1(digit) ~ scientificNotation) ^^ - { case i ~ s => DecimalLit("0." + i.mkString + s) } - | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ - { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) } - | digit.* ~ identChar ~ (identChar | digit).* ^^ - { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } - | rep1(digit) ~ ('.' ~> digit.*).? ^^ { - case i ~ None => NumericLit(i.mkString) - case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) - } - | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ - { case chars => StringLit(chars mkString "") } - | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ - { case chars => StringLit(chars mkString "") } - | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^ - { case chars => Identifier(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar: Parser[Elem] = letter | elem('_') - - private lazy val scientificNotation: Parser[String] = - (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { - case s ~ rest => "e" + s.mkString + rest.mkString - } - - override def whitespace: Parser[Any] = - ( whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ chrExcept(EofCh, '\n').* - | '#' ~ chrExcept(EofCh, '\n').* - | '-' ~ '-' ~ chrExcept(EofCh, '\n').* - | '/' ~ '*' ~ failure("unclosed comment") - ).* -} - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala new file mode 100644 index 0000000000..c350f3049f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -0,0 +1,1460 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or + * TableIdentifier. + */ +class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { + import ParserUtils._ + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + visit(ctx.dataType).asInstanceOf[DataType] + } + + /* ******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Make sure we do not try to create a plan for a native command. + */ + override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null + + /** + * Create a plan for a SHOW FUNCTIONS command. + */ + override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { + import ctx._ + if (qualifiedName != null) { + val names = qualifiedName().identifier().asScala.map(_.getText).toList + names match { + case db :: name :: Nil => + ShowFunctions(Some(db), Some(name)) + case name :: Nil => + ShowFunctions(None, Some(name)) + case _ => + throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) + } + } else if (pattern != null) { + ShowFunctions(None, Some(string(pattern))) + } else { + ShowFunctions(None, None) + } + } + + /** + * Create a plan for a DESCRIBE FUNCTION command. 
+ */ + override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { + val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".") + DescribeFunction(functionName, ctx.EXTENDED != null) + } + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryNoWith) + + // Apply CTEs + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { + case nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + + // Check for duplicate names. + ctes.groupBy(_._1).filter(_._2.size > 1).foreach { + case (name, _) => + throw new ParseException( + s"Name '$name' is used for multiple common table expressions", ctx) + } + + With(query, ctes.toMap) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith)) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { + body => + assert(body.querySpecification.fromClause == null, + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", + body) + + withQuerySpecification(body.querySpecification, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) + } + + // If there are multiple INSERTS just UNION them together into one query. + inserts match { + case Seq(query) => query + case queries => Union(queries) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) + } + + /** + * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan. + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + InsertIntoTable( + UnresolvedRelation(tableIdent, None), + partitionKeys, + query, + ctx.OVERWRITE != null, + ctx.EXISTS != null) + } + + /** + * Create a partition specification map. 
+ */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText.toLowerCase + val value = Option(pVal.constant).map(visitStringConstant) + name -> value + }.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. + */ + protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { + ctx match { + case s: StringLiteralContext => createString(s) + case o => o.getText + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem), global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem), global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + RepartitionByExpression(expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem), + global = false, + RepartitionByExpression(expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + RepartitionByExpression(expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windows)(withWindows) + + // LIMIT + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a logical plan using a query specification. + */ + override def visitQuerySpecification( + ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation.optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withQuerySpecification(ctx, from) + } + + /** + * Add a query specification to a logical plan. The query specification is the core of the logical + * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), + * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). 
+ */ + private def withQuerySpecification( + ctx: QuerySpecificationContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // WHERE + def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx), plan) + } + + // Expressions. + val expressions = Option(namedExpressionSeq).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + + // Create either a transform or a regular query. + val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT) + specType match { + case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM => + // Transform + + // Add where. + val withFilter = relation.optionalMap(where)(filter) + + // Create the attributes. + val (attributes, schemaLess) = if (colTypeList != null) { + // Typed return columns. + (createStructType(colTypeList).toAttributes, false) + } else if (identifierSeq != null) { + // Untyped return columns. + val attrs = visitIdentifierSeq(identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + // Create the transform. + ScriptTransformation( + expressions, + string(script), + attributes, + withFilter, + withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + + case SqlBaseParser.SELECT => + // Regular select + + // Add lateral views. + val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(where)(filter) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + val withProject = if (aggregation != null) { + withAggregation(aggregation, namedExpressions, withFilter) + } else if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + // Having + val withHaving = withProject.optional(having) { + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(expression(having), BooleanType), withProject) + } + + // Distinct + val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { + Distinct(withHaving) + } else { + withHaving + } + + // Window + withDistinct.optionalMap(windows)(withWindows) + } + } + + /** + * Create a (Hive based) [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = null + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. + */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + + /** + * Connect two queries by a Set operator. 
+ * + * Supported Set operators are: + * - UNION [DISTINCT] + * - UNION ALL + * - EXCEPT [DISTINCT] + * - INTERSECT [DISTINCT] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case SqlBaseParser.UNION if all => + Union(left, right) + case SqlBaseParser.UNION => + Distinct(Union(left, right)) + case SqlBaseParser.INTERSECT if all => + throw new ParseException("INTERSECT ALL is not supported.", ctx) + case SqlBaseParser.INTERSECT => + Intersect(left, right) + case SqlBaseParser.EXCEPT if all => + throw new ParseException("EXCEPT ALL is not supported.", ctx) + case SqlBaseParser.EXCEPT => + Except(left, right) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindows( + ctx: WindowsContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowMap = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + }.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity), query) + } + + /** + * Add an [[Aggregate]] to a logical plan. + */ + private def withAggregation( + ctx: AggregationContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + val groupByExpressions = expressionList(groupingExpressions) + + if (GROUPING != null) { + // GROUP BY .... GROUPING SETS (...) + val expressionMap = groupByExpressions.zipWithIndex.toMap + val numExpressions = expressionMap.size + val mask = (1 << numExpressions) - 1 + val masks = ctx.groupingSet.asScala.map { + _.expression.asScala.foldLeft(mask) { + case (bitmap, eCtx) => + // Find the index of the expression. + val e = typedVisit[Expression](eCtx) + val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( + throw new ParseException( + s"$e doesn't show up in the GROUP BY list", ctx)) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + bitmap & ~(1 << (numExpressions - 1 - index)) + } + } + GroupingSets(masks, groupByExpressions, query, selectExpressions) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (CUBE != null) { + Seq(Cube(groupByExpressions)) + } else if (ROLLUP != null) { + Seq(Rollup(groupByExpressions)) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. 
+ */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + + // Create the generator. + val generator = ctx.qualifiedName.getText.toLowerCase match { + case "explode" if expressions.size == 1 => + Explode(expressions.head) + case "json_tuple" => + JsonTuple(expressions) + case other => + withGenerator(other, expressions, ctx) + } + + Generate( + generator, + join = true, + outer = ctx.OUTER != null, + Some(ctx.tblName.getText.toLowerCase), + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), + query) + } + + /** + * Create a [[Generator]]. Override this method in order to support custom Generators. + */ + protected def withGenerator( + name: String, + expressions: Seq[Expression], + ctx: LateralViewContext): Generator = { + throw new ParseException(s"Generator function '$name' is not supported", ctx) + } + + /** + * Create a joins between two or more logical plans. + */ + override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { + /** Build a join between two plans. */ + def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { + val baseJoinType = ctx.joinType match { + case null => Inner + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if ctx.NATURAL != null => + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, right, joinType, condition) + } + + // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the + // first join clause is at the top. However fields of previously referenced tables can be used + // in following join clauses. The tree needs to be reversed in order to make this work. + var result = plan(ctx.left) + var current = ctx + while (current != null) { + current.right match { + case right: JoinRelationContext => + result = join(current, result, plan(right.left)) + current = right + case right => + result = join(current, result, plan(right)) + current = null + } + } + result + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. 
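+      // Illustrative note (based on the PERCENTLIT branch below): "TABLESAMPLE(30 PERCENT)"
+      // reaches this helper as fraction = 0.3, because the parsed percentage is divided by 100
+      // before sample() is called, so only the [0, 1] range needs to be validated here.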
+ val eps = RandomSampler.roundingEpsilon + assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + } + + ctx.sampleType.getType match { + case SqlBaseParser.ROWS => + Limit(expression(ctx.expression), query) + + case SqlBaseParser.PERCENTLIT => + val fraction = ctx.percentage.getText.toDouble + sample(fraction / 100.0d) + + case SqlBaseParser.BUCKET if ctx.ON != null => + throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) + + case SqlBaseParser.BUCKET => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val table = UnresolvedRelation( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.identifier).map(_.getText)) + table.optionalMap(ctx.sample)(withSample) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val expressions = ctx.expression.asScala.map { eCtx => + val e = expression(eCtx) + assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) + e + } + + // Validate and evaluate the rows. + val (structType, structConstructor) = expressions.head.dataType match { + case st: StructType => + (st, (e: Expression) => e) + case dt => + val st = CreateStruct(Seq(expressions.head)).dataType + (st, (e: Expression) => CreateStruct(Seq(e))) + } + val rows = expressions.map { + case expression => + val safe = Cast(structConstructor(expression), structType) + safe.eval().asInstanceOf[InternalRow] + } + + // Construct attributes. + val baseAttributes = structType.toAttributes.map(_.withNullability(true)) + val attributes = if (ctx.identifierList != null) { + val aliases = visitIdentifierList(ctx.identifierList) + assert(aliases.size == baseAttributes.size, + "Number of aliases must match the number of fields in an inline table.", ctx) + baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + } else { + baseAttributes + } + + // Create plan and add an alias if a name has been defined. + LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. 
This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a LogicalPlan. + */ + private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. + */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText) + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create an expression from the given context. This method just passes the context on to the + * vistor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression) + } + + /** + * Invert a boolean expression if it has a valid NOT clause. + */ + private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { + if (not != null) { + Not(expression) + } else { + expression + } + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. 
+ */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case SqlBaseParser.AND => And.apply _ + case SqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. + val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. + val expressions = contexts.reverse.map(expression) + + // Create a balanced tree. + def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query. This is not supported yet. + */ + override def visitExists(ctx: ExistsContext): Expression = { + throw new ParseException("EXISTS clauses are not supported.", ctx) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case SqlBaseParser.EQ => + EqualTo(left, right) + case SqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case SqlBaseParser.NEQ | SqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case SqlBaseParser.LT => + LessThan(left, right) + case SqlBaseParser.LTE => + LessThanOrEqual(left, right) + case SqlBaseParser.GT => + GreaterThan(left, right) + case SqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two + * other expressions. The inverse can also be created. + */ + override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { + val value = expression(ctx.value) + val between = And( + GreaterThanOrEqual(value, expression(ctx.lower)), + LessThanOrEqual(value, expression(ctx.upper))) + invertIfNotDefined(between, ctx.NOT) + } + + /** + * Create an IN expression. This tests if the value of the left hand side expression is + * contained by the sequence of expressions on the right hand side. 
+ */ + override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { + val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) + invertIfNotDefined(in, ctx.NOT) + } + + /** + * Create an IN expression, where the the right hand side is a query. This is unsupported. + */ + override def visitInSubquery(ctx: InSubqueryContext): Expression = { + throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + } + + /** + * Create a (R)LIKE/REGEXP expression. + */ + override def visitLike(ctx: LikeContext): Expression = { + val left = expression(ctx.value) + val right = expression(ctx.pattern) + val like = ctx.like.getType match { + case SqlBaseParser.LIKE => + Like(left, right) + case SqlBaseParser.RLIKE => + RLike(left, right) + } + invertIfNotDefined(like, ctx.NOT) + } + + /** + * Create an IS (NOT) NULL expression. + */ + override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { + val value = expression(ctx.value) + if (ctx.NOT != null) { + IsNotNull(value) + } else { + IsNull(value) + } + } + + /** + * Create a binary arithmetic expression. The following arithmetic operators are supported: + * - Mulitplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case SqlBaseParser.ASTERISK => + Multiply(left, right) + case SqlBaseParser.SLASH => + Divide(left, right) + case SqlBaseParser.PERCENT => + Remainder(left, right) + case SqlBaseParser.DIV => + Cast(Divide(left, right), LongType) + case SqlBaseParser.PLUS => + Add(left, right) + case SqlBaseParser.MINUS => + Subtract(left, right) + case SqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case SqlBaseParser.HAT => + BitwiseXor(left, right) + case SqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case SqlBaseParser.PLUS => + value + case SqlBaseParser.MINUS => + UnaryMinus(value) + case SqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + Cast(expression(ctx.expression), typedVisit(ctx.dataType)) + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + // Create the function call. + val name = ctx.qualifiedName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + val arguments = ctx.expression().asScala.map(expression) match { + case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). Move this to analysis? + Seq(Literal(1)) + case expressions => + expressions + } + val function = UnresolvedFunction(name, arguments, isDistinct) + + // Check if the function is evaluated in a windowed context. 
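+    // For example, "max(a) OVER w" yields an UnresolvedWindowExpression that still refers to the
+    // named window 'w', "max(a) OVER (PARTITION BY b)" is wrapped directly in a WindowExpression,
+    // and a call without an OVER clause is returned unchanged.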
+ ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + /** + * Create a reference to a window frame, i.e. [[WindowSpecReference]]. + */ + override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { + WindowSpecReference(ctx.identifier.getText) + } + + /** + * Create a window definition, i.e. [[WindowSpecDefinition]]. + */ + override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { + // CLUSTER BY ... | PARTITION BY ... ORDER BY ... + val partition = ctx.partition.asScala.map(expression) + val order = ctx.sortItem.asScala.map(visitSortItem) + + // RANGE/ROWS BETWEEN ... + val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case SqlBaseParser.RANGE => RangeFrame + case SqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition, + order, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value + * Preceding/Following boundaries. These expressions must be constant (foldable) and return an + * integer value. + */ + override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { + // We currently only allow foldable integers. + def value: Int = { + val e = expression(ctx.expression) + assert(e.resolved && e.foldable && e.dataType == IntegerType, + "Frame bound value must be a constant integer.", + ctx) + e.eval().asInstanceOf[Int] + } + + // Create the FrameBoundary + ctx.boundType.getType match { + case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case SqlBaseParser.PRECEDING => + ValuePreceding(value) + case SqlBaseParser.CURRENT => + CurrentRow + case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case SqlBaseParser.FOLLOWING => + ValueFollowing(value) + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.expression.asScala.map(expression)) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... 
+ * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a dereference expression. The return type depends on the type of the parent, this can + * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an + * [[UnresolvedExtractValue]] if the parent is some expression. + */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ attr) + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression. + */ + override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { + UnresolvedAttribute.quoted(ctx.getText) + } + + /** + * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. + */ + override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { + UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) + } + + /** + * Create an expression for an expression between parentheses. This is need because the ANTLR + * visitor cannot automatically convert the nested context into an expression. + */ + override def visitParenthesizedExpression( + ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { + expression(ctx.expression) + } + + /** + * Create a [[SortOrder]] expression. + */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + if (ctx.DESC != null) { + SortOrder(expression(ctx.expression), Descending) + } else { + SortOrder(expression(ctx.expression), Ascending) + } + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date and Timestamp typed literals are supported. + * + * TODO what the added value of this over casting? + */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + ctx.identifier.getText.toUpperCase match { + case "DATE" => + Literal(Date.valueOf(value)) + case "TIMESTAMP" => + Literal(Timestamp.valueOf(value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. 
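   * For example, "123" becomes an Int literal, "3000000000" (too large for an Int) becomes a
   * Long literal, and "9223372036854775808" (one past Long.MaxValue) falls back to a
   * BigDecimal literal.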
+ */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue()) + case v if v.isValidLong => + Literal(v.longValue()) + case v => Literal(v.underlying()) + } + } + + /** + * Create a double literal for a number denoted in scientifc notation. + */ + override def visitScientificDecimalLiteral( + ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(ctx.getText.toDouble) + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** Create a numeric literal expression. */ + private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { + val raw = ctx.getText + try { + Literal(f(raw.substring(0, raw.length - 1))) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { + _.toByte + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { + _.toShort + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { + _.toLong + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { + _.toDouble + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + ctx.STRING().asScala.map(string).mkString + } + + /** + * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple + * unit value pairs, for instance: interval 2 months 2 days. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val intervals = ctx.intervalField.asScala.map(visitIntervalField) + assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + Literal(intervals.reduce(_.add(_))) + } + + /** + * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are + * supported: + * - Single unit. + * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). + */ + override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { + import ctx._ + val s = value.getText + try { + val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + case (u, None) if u.endsWith("s") => + // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... 
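        // e.g. "interval 2 days" reaches this case with unit "days", which is trimmed to
        // "day" before being passed to CalendarInterval.fromSingleUnitString.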
+ CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) + case (u, None) => + CalendarInterval.fromSingleUnitString(u, s) + case ("year", Some("month")) => + CalendarInterval.fromYearMonthString(s) + case ("day", Some("second")) => + CalendarInterval.fromDayTimeString(s) + case (from, Some(t)) => + throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) + } + assert(interval != null, "No interval can be constructed", ctx) + interval + } catch { + // Handle Exceptions thrown by CalendarInterval + case e: IllegalArgumentException => + val pe = new ParseException(e.getMessage, ctx) + pe.setStackTrace(e.getStackTrace) + throw pe + } + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + /** + * Resolve/create a primitive type. + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => TimestampType + case ("char" | "varchar" | "string", Nil) => StringType + case ("char" | "varchar", _ :: Nil) => StringType + case ("binary", Nil) => BinaryType + case ("decimal", Nil) => DecimalType.USER_DEFAULT + case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) + case ("decimal", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case (dt, params) => + throw new ParseException( + s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case SqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case SqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case SqlBaseParser.STRUCT => + createStructType(ctx.colTypeList()) + } + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType) + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + // Add the comment to the metadata. 
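    // A definition such as "name STRING COMMENT 'customer name'" carries an optional STRING
    // token; when present, its value is stored under the "comment" key of the field metadata.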
+ val builder = new MetadataBuilder + if (STRING != null) { + builder.putString("comment", string(STRING)) + } + + StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala deleted file mode 100644 index c188c5b108..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala +++ /dev/null @@ -1,933 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser - -import java.sql.Date - -import scala.collection.mutable.ArrayBuffer -import scala.util.matching.Regex - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Count -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler - - -/** - * This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s. - */ -private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserInterface { - import ParserUtils._ - - /** - * The safeParse method allows a user to focus on the parsing/AST transformation logic. This - * method will take care of possible errors during the parsing process. - */ - protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = { - try { - toResult(ast) - } catch { - case e: MatchError => throw e - case e: AnalysisException => throw e - case e: Exception => - throw new AnalysisException(e.getMessage) - case e: NotImplementedError => - throw new AnalysisException( - s"""Unsupported language features in query - |== SQL == - |$sql - |== AST == - |${ast.treeString} - |== Error == - |$e - |== Stacktrace == - |${e.getStackTrace.head} - """.stripMargin) - } - } - - /** Creates LogicalPlan for a given SQL string. */ - def parsePlan(sql: String): LogicalPlan = - safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan) - - /** Creates Expression for a given SQL string. */ - def parseExpression(sql: String): Expression = - safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get) - - /** Creates TableIdentifier for a given SQL string. 
*/ - def parseTableIdentifier(sql: String): TableIdentifier = - safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent) - - /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to - * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 - * Check the following link for details. - * -https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup - * - * The bitmask denotes the grouping expressions validity for a grouping set, - * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively. - */ - protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { - val (keyASTs, setASTs) = children.partition { - case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets - case _ => true // grouping keys - } - - val keys = keyASTs.map(nodeToExpr) - val keyMap = keyASTs.zipWithIndex.toMap - - val mask = (1 << keys.length) - 1 - val bitmasks: Seq[Int] = setASTs.map { - case Token("TOK_GROUPING_SETS_EXPRESSION", columns) => - columns.foldLeft(mask)((bitmap, col) => { - val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse( - throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list")) - // 0 means that the column at the given index is a grouping column, 1 means it is not, - // so we unset the bit in bitmap. - bitmap & ~(1 << (keys.length - 1 - keyIndex)) - }) - case _ => sys.error("Expect GROUPING SETS clause") - } - - (keys, bitmasks) - } - - protected def nodeToPlan(node: ASTNode): LogicalPlan = node match { - case Token("TOK_SHOWFUNCTIONS", args) => - // Skip LIKE. - val pattern = args match { - case like :: nodes if like.text.toUpperCase == "LIKE" => nodes - case nodes => nodes - } - - // Extract Database and Function name - pattern match { - case Nil => - ShowFunctions(None, None) - case Token(name, Nil) :: Nil => - ShowFunctions(None, Some(unquoteString(cleanIdentifier(name)))) - case Token(db, Nil) :: Token(name, Nil) :: Nil => - ShowFunctions(Some(unquoteString(cleanIdentifier(db))), - Some(unquoteString(cleanIdentifier(name)))) - case _ => - noParseRule("SHOW FUNCTIONS", node) - } - - case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) => - DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty) - - case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => - val (fromClause: Option[ASTNode], insertClauses, cteRelations) = - queryArgs match { - case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => - val cteRelations = ctes.map { node => - val relation = nodeToRelation(node).asInstanceOf[SubqueryAlias] - relation.alias -> relation - } - (Some(from.head), inserts, Some(cteRelations.toMap)) - case Token("TOK_FROM", from) :: inserts => - (Some(from.head), inserts, None) - case Token("TOK_INSERT", _) :: Nil => - (None, queryArgs, None) - } - - // Return one query for each insert clause. 
- val queries = insertClauses.map { - case Token("TOK_INSERT", singleInsert) => - val ( - intoClause :: - destClause :: - selectClause :: - selectDistinctClause :: - whereClause :: - groupByClause :: - rollupGroupByClause :: - cubeGroupByClause :: - groupingSetsClause :: - orderByClause :: - havingClause :: - sortByClause :: - clusterByClause :: - distributeByClause :: - limitClause :: - lateralViewClause :: - windowClause :: Nil) = { - getClauses( - Seq( - "TOK_INSERT_INTO", - "TOK_DESTINATION", - "TOK_SELECT", - "TOK_SELECTDI", - "TOK_WHERE", - "TOK_GROUPBY", - "TOK_ROLLUP_GROUPBY", - "TOK_CUBE_GROUPBY", - "TOK_GROUPING_SETS", - "TOK_ORDERBY", - "TOK_HAVING", - "TOK_SORTBY", - "TOK_CLUSTERBY", - "TOK_DISTRIBUTEBY", - "TOK_LIMIT", - "TOK_LATERAL_VIEW", - "WINDOW"), - singleInsert) - } - - val relations = fromClause match { - case Some(f) => nodeToRelation(f) - case None => OneRowRelation - } - - val withLateralView = lateralViewClause.map { lv => - nodeToGenerate(lv.children.head, outer = false, relations) - }.getOrElse(relations) - - val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.children - Filter(nodeToExpr(whereExpr), withLateralView) - }.getOrElse(withLateralView) - - val select = (selectClause orElse selectDistinctClause) - .getOrElse(sys.error("No select clause.")) - - val transformation = nodeToTransformation(select.children.head, withWhere) - - // The projection of the query can either be a normal projection, an aggregation - // (if there is a group by) or a script transformation. - val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = - select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)) - Seq( - groupByClause.map(e => e match { - case Token("TOK_GROUPBY", children) => - // Not a transformation so must be either project or aggregation. - Aggregate(children.map(nodeToExpr), selectExpressions, withWhere) - case _ => sys.error("Expect GROUP BY") - }), - groupingSetsClause.map(e => e match { - case Token("TOK_GROUPING_SETS", children) => - val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withWhere, selectExpressions) - case _ => sys.error("Expect GROUPING SETS") - }), - rollupGroupByClause.map(e => e match { - case Token("TOK_ROLLUP_GROUPBY", children) => - Aggregate( - Seq(Rollup(children.map(nodeToExpr))), - selectExpressions, - withWhere) - case _ => sys.error("Expect WITH ROLLUP") - }), - cubeGroupByClause.map(e => e match { - case Token("TOK_CUBE_GROUPBY", children) => - Aggregate( - Seq(Cube(children.map(nodeToExpr))), - selectExpressions, - withWhere) - case _ => sys.error("Expect WITH CUBE") - }), - Some(Project(selectExpressions, withWhere))).flatten.head - } - - // Handle HAVING clause. - val withHaving = havingClause.map { h => - val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) } - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(havingExpr, BooleanType), withProject) - }.getOrElse(withProject) - - // Handle SELECT DISTINCT - val withDistinct = - if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. 
- val withSort = - (orderByClause, sortByClause, distributeByClause, clusterByClause) match { - case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct) - case (None, Some(perPartitionOrdering), None, None) => - Sort( - perPartitionOrdering.children.map(nodeToSortOrder), - global = false, withDistinct) - case (None, None, Some(partitionExprs), None) => - RepartitionByExpression( - partitionExprs.children.map(nodeToExpr), withDistinct) - case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort( - perPartitionOrdering.children.map(nodeToSortOrder), global = false, - RepartitionByExpression( - partitionExprs.children.map(nodeToExpr), - withDistinct)) - case (None, None, None, Some(clusterExprs)) => - Sort( - clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)), - global = false, - RepartitionByExpression( - clusterExprs.children.map(nodeToExpr), - withDistinct)) - case (None, None, None, None) => withDistinct - case _ => sys.error("Unsupported set of ordering / distribution clauses.") - } - - val withLimit = - limitClause.map(l => nodeToExpr(l.children.head)) - .map(Limit(_, withSort)) - .getOrElse(withSort) - - // Collect all window specifications defined in the WINDOW clause. - val windowDefinitions = windowClause.map(_.children.collect { - case Token("TOK_WINDOWDEF", - Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - windowName -> nodesToWindowSpecification(spec) - }.toMap) - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val resolvedCrossReference = windowDefinitions.map { - windowDefMap => windowDefMap.map { - case (windowName, WindowSpecReference(other)) => - (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) - case o => o.asInstanceOf[(String, WindowSpecDefinition)] - } - } - - val withWindowDefinitions = - resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) - - // TOK_INSERT_INTO means to add files to the table. - // TOK_DESTINATION means to overwrite the table. - val resultDestination = - (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = intoClause.isEmpty - nodeToDest( - resultDestination, - withWindowDefinitions, - overwrite) - } - - // If there are multiple INSERTS just UNION them together into one query. 
- val query = if (queries.length == 1) queries.head else Union(queries) - - // return With plan if there is CTE - cteRelations.map(With(query, _)).getOrElse(query) - - case Token("TOK_UNIONALL", left :: right :: Nil) => - Union(nodeToPlan(left), nodeToPlan(right)) - case Token("TOK_UNIONDISTINCT", left :: right :: Nil) => - Distinct(Union(nodeToPlan(left), nodeToPlan(right))) - case Token("TOK_EXCEPT", left :: right :: Nil) => - Except(nodeToPlan(left), nodeToPlan(right)) - case Token("TOK_INTERSECT", left :: right :: Nil) => - Intersect(nodeToPlan(left), nodeToPlan(right)) - - case _ => - noParseRule("Plan", node) - } - - val allJoinTokens = "(TOK_.*JOIN)".r - val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - protected def nodeToRelation(node: ASTNode): LogicalPlan = { - node match { - case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - SubqueryAlias(cleanIdentifier(alias), nodeToPlan(query)) - - case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => - nodeToGenerate( - selectClause, - outer = isOuter.nonEmpty, - nodeToRelation(relationClause)) - - /* All relations, possibly with aliases or sampling clauses. */ - case Token("TOK_TABREF", clauses) => - // If the last clause is not a token then it's the alias of the table. - val (nonAliasClauses, aliasClause) = - if (clauses.last.text.startsWith("TOK")) { - (clauses, None) - } else { - (clauses.dropRight(1), Some(clauses.last)) - } - - val (Some(tableNameParts) :: - splitSampleClause :: - bucketSampleClause :: Nil) = { - getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), - nonAliasClauses) - } - - val tableIdent = extractTableIdent(tableNameParts) - val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } - val relation = UnresolvedRelation(tableIdent, alias) - - // Apply sampling if requested. - (bucketSampleClause orElse splitSampleClause).map { - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) => - Limit(Literal(count.toInt), relation) - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. 
- require( - fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) - && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), - s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(0.0, fraction.toDouble / 100, withReplacement = false, - (math.random * 1000).toInt, - relation)( - isTableSample = true) - case Token("TOK_TABLEBUCKETSAMPLE", - Token(numerator, Nil) :: - Token(denominator, Nil) :: Nil) => - val fraction = numerator.toDouble / denominator.toDouble - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)( - isTableSample = true) - case a => - noParseRule("Sampling", a) - }.getOrElse(relation) - - case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) => - if (!(other.size <= 1)) { - sys.error(s"Unsupported join operation: $other") - } - - val (joinType, joinCondition) = getJoinInfo(joinToken, other, node) - - Join(nodeToRelation(relation1), - nodeToRelation(relation2), - joinType, - joinCondition) - case _ => - noParseRule("Relation", node) - } - } - - protected def getJoinInfo( - joinToken: String, - joinConditionToken: Seq[ASTNode], - node: ASTNode): (JoinType, Option[Expression]) = { - val joinType = joinToken match { - case "TOK_JOIN" => Inner - case "TOK_CROSSJOIN" => Inner - case "TOK_RIGHTOUTERJOIN" => RightOuter - case "TOK_LEFTOUTERJOIN" => LeftOuter - case "TOK_FULLOUTERJOIN" => FullOuter - case "TOK_LEFTSEMIJOIN" => LeftSemi - case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node) - case "TOK_ANTIJOIN" => noParseRule("Anti Join", node) - case "TOK_NATURALJOIN" => NaturalJoin(Inner) - case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter) - case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter) - case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter) - } - - joinConditionToken match { - case Token("TOK_USING", columnList :: Nil) :: Nil => - val colNames = columnList.children.collect { - case Token(name, Nil) => UnresolvedAttribute(name) - } - (UsingJoin(joinType, colNames), None) - /* Join expression specified using ON clause */ - case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr)) - } - } - - protected def nodeToSortOrder(node: ASTNode): SortOrder = node match { - case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Descending) - case _ => - noParseRule("SortOrder", node) - } - - val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r - protected def nodeToDest( - node: ASTNode, - query: LogicalPlan, - overwrite: Boolean): LogicalPlan = node match { - case Token(destinationToken(), - Token("TOK_DIR", - Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => - query - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.children.map { - // Parse partitions. We also make keys case insensitive. 
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable( - UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = false) - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: - Token("TOK_IFNOTEXISTS", - ifNotExists) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.children.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable( - UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = true) - - case _ => - noParseRule("Destination", node) - } - - protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match { - case Token("TOK_SELEXPR", e :: Nil) => - Some(nodeToExpr(e)) - - case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) - - case Token("TOK_SELEXPR", e :: aliasChildren) => - val aliasNames = aliasChildren.collect { - case Token(name, Nil) => cleanIdentifier(name) - } - Some(MultiAlias(nodeToExpr(e), aliasNames)) - - /* Hints are ignored */ - case Token("TOK_HINTLIST", _) => None - - case _ => - noParseRule("Select", node) - } - - /** - * Flattens the left deep tree with the specified pattern into a list. - */ - private def flattenLeftDeepTree(node: ASTNode, pattern: Regex): Seq[ASTNode] = { - val collected = ArrayBuffer[ASTNode]() - var rest = node - while (rest match { - case Token(pattern(), l :: r :: Nil) => - collected += r - rest = l - true - case _ => false - }) { - // do nothing - } - collected += rest - // keep them in the same order as in SQL - collected.reverse - } - - /** - * Creates a balanced tree that has similar number of nodes on left and right. - * - * This help to reduce the depth of the tree to prevent StackOverflow in analyzer/optimizer. - */ - private def balancedTree( - expr: Seq[Expression], - f: (Expression, Expression) => Expression): Expression = expr.length match { - case 1 => expr.head - case 2 => f(expr.head, expr(1)) - case l => f(balancedTree(expr.slice(0, l / 2), f), balancedTree(expr.slice(l / 2, l), f)) - } - - protected def nodeToExpr(node: ASTNode): Expression = node match { - /* Attribute References */ - case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => - UnresolvedAttribute.quoted(cleanIdentifier(name)) - case Token(".", qualifier :: Token(attr, Nil) :: Nil) => - nodeToExpr(qualifier) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr))) - } - case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) => - ScalarSubquery(nodeToPlan(subquery)) - - /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) - // The format of dbName.tableName.* cannot be parsed by HiveParser. 
TOK_TABNAME will only - // has a single child which is tableName. - case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty => - UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text)))) - - /* Aggregate Functions */ - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => - Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => - Count(Literal(1)).toAggregateExpression() - - /* Casts */ - case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), IntegerType) - case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), LongType) - case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), FloatType) - case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DoubleType) - case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ShortType) - case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ByteType) - case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BinaryType) - case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BooleanType) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) - case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), TimestampType) - case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DateType) - - /* Arithmetic */ - case Token("+", child :: Nil) => nodeToExpr(child) - case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) - case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) - case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) - case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) - case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) - case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => - Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) - case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) - case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) - case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) - case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - - /* Comparisons */ - case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), 
nodeToExpr(right)) - case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) - case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) - case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) - case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => - IsNotNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => - IsNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => - In(nodeToExpr(value), list.map(nodeToExpr)) - case Token("TOK_FUNCTION", - Token(BETWEEN(), Nil) :: - kw :: - target :: - minValue :: - maxValue :: Nil) => - - val targetExpression = nodeToExpr(target) - val betweenExpr = - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) - kw match { - case Token("KW_FALSE", Nil) => betweenExpr - case Token("KW_TRUE", Nil) => Not(betweenExpr) - } - - /* Boolean Logic */ - case Token(AND(), left :: right:: Nil) => - balancedTree(flattenLeftDeepTree(node, AND).map(nodeToExpr), And) - case Token(OR(), left :: right:: Nil) => - balancedTree(flattenLeftDeepTree(node, OR).map(nodeToExpr), Or) - case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) - case Token("!", child :: Nil) => Not(nodeToExpr(child)) - - /* Case statements */ - case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => - CaseWhen.createFromParser(branches.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => - val keyExpr = nodeToExpr(branches.head) - CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) - - /* Complex datatype manipulation */ - case Token("[", child :: ordinal :: Nil) => - UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - - /* Window Functions */ - case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = nodeToExpr(node.copy(children = node.children.init)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - - /* UDFs - Must be last otherwise will preempt built in functions */ - case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) - // Aggregate function with DISTINCT keyword. 
- case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) - - /* Literals */ - case Token("TOK_NULL", Nil) => Literal.create(null, NullType) - case Token(TRUE(), Nil) => Literal.create(true, BooleanType) - case Token(FALSE(), Nil) => Literal.create(false, BooleanType) - case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString) - - case ast if ast.tokenType == SparkSqlParser.TinyintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) - - case ast if ast.tokenType == SparkSqlParser.SmallintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) - - case ast if ast.tokenType == SparkSqlParser.BigintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) - - case ast if ast.tokenType == SparkSqlParser.DoubleLiteral => - Literal(ast.text.toDouble) - - case ast if ast.tokenType == SparkSqlParser.Number => - val text = ast.text - text match { - case INTEGRAL() => - BigDecimal(text) match { - case v if v.isValidInt => - Literal(v.intValue()) - case v if v.isValidLong => - Literal(v.longValue()) - case v => Literal(v.underlying()) - } - case DECIMAL(_*) => - Literal(BigDecimal(text).underlying()) - case _ => - // Convert a scientifically notated decimal into a double. - Literal(text.toDouble) - } - case ast if ast.tokenType == SparkSqlParser.StringLiteral => - Literal(ParseUtils.unescapeSQLString(ast.text)) - - case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) - - case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.children.head.text)) - - case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.children.head.text)) - - case Token("TOK_INTERVAL", elements) => - var interval = new CalendarInterval(0, 0) - var updated = false - elements.foreach { - // The interval node will always contain children for all possible time units. A child node - // is only useful when it contains exactly one (numeric) child. 
- case e @ Token(name, Token(value, Nil) :: Nil) => - val unit = name match { - case "TOK_INTERVAL_YEAR_LITERAL" => "year" - case "TOK_INTERVAL_MONTH_LITERAL" => "month" - case "TOK_INTERVAL_WEEK_LITERAL" => "week" - case "TOK_INTERVAL_DAY_LITERAL" => "day" - case "TOK_INTERVAL_HOUR_LITERAL" => "hour" - case "TOK_INTERVAL_MINUTE_LITERAL" => "minute" - case "TOK_INTERVAL_SECOND_LITERAL" => "second" - case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond" - case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond" - case _ => noParseRule(s"Interval($name)", e) - } - interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value)) - updated = true - case _ => - } - if (!updated) { - throw new AnalysisException("at least one time unit should be given for interval literal") - } - Literal(interval) - - case _ => - noParseRule("Expression", node) - } - - /* Case insensitive matches for Window Specification */ - val PRECEDING = "(?i)preceding".r - val FOLLOWING = "(?i)following".r - val CURRENT = "(?i)current".r - protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { - case Token(windowName, Nil) :: Nil => - // Refer to a window spec defined in the window clause. - WindowSpecReference(windowName) - case Nil => - // OVER() - WindowSpecDefinition( - partitionSpec = Nil, - orderSpec = Nil, - frameSpecification = UnspecifiedFrame) - case spec => - val (partitionClause :: rowFrame :: rangeFrame :: Nil) = - getClauses( - Seq( - "TOK_PARTITIONINGSPEC", - "TOK_WINDOWRANGE", - "TOK_WINDOWVALUES"), - spec) - - // Handle Partition By and Order By. - val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => - val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = - getClauses( - Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.children) - - (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { - case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.children.map(nodeToExpr), - orderByExpr.children.map(nodeToSortOrder)) - case (Some(partitionByExpr), None, None) => - (partitionByExpr.children.map(nodeToExpr), Nil) - case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.children.map(nodeToSortOrder)) - case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.children.map(nodeToExpr) - (expressions, expressions.map(SortOrder(_, Ascending))) - case _ => - noParseRule("Partition & Ordering", partitionAndOrdering) - } - }.getOrElse { - (Nil, Nil) - } - - // Handle Window Frame - val windowFrame = - if (rowFrame.isEmpty && rangeFrame.isEmpty) { - UnspecifiedFrame - } else { - val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) - def nodeToBoundary(node: ASTNode): FrameBoundary = node match { - case Token(PRECEDING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedPreceding - } else { - ValuePreceding(count.toInt) - } - case Token(FOLLOWING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedFollowing - } else { - ValueFollowing(count.toInt) - } - case Token(CURRENT(), Nil) => CurrentRow - case _ => - noParseRule("Window Frame Boundary", node) - } - - rowFrame.orElse(rangeFrame).map { frame => - frame.children match { - case precedingNode :: followingNode :: Nil => - SpecifiedWindowFrame( - frameType, - nodeToBoundary(precedingNode), - nodeToBoundary(followingNode)) - case precedingNode :: Nil => - 
SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) - case _ => - noParseRule("Window Frame", frame) - } - }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) - } - - WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) - } - - protected def nodeToTransformation( - node: ASTNode, - child: LogicalPlan): Option[ScriptTransformation] = None - - val explode = "(?i)explode".r - val jsonTuple = "(?i)json_tuple".r - protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = { - val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node - - val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text) - - val generator = clauses.head match { - case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) => - Explode(nodeToExpr(childNode)) - case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => - JsonTuple(children.map(nodeToExpr)) - case other => - nodeToGenerator(other) - } - - val attributes = clauses.collect { - case Token(a, Nil) => UnresolvedAttribute(cleanIdentifier(a.toLowerCase)) - } - - Generate( - generator, - join = true, - outer = outer, - Some(cleanIdentifier(alias.toLowerCase)), - attributes, - child) - } - - protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node) - -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala index 21deb82107..0b570c9e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser import scala.language.implicitConversions import scala.util.matching.Regex import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.input.CharArrayReader._ import org.apache.spark.sql.types._ @@ -117,3 +118,69 @@ private[sql] object DataTypeParser { /** The exception thrown from the [[DataTypeParser]]. */ private[sql] class DataTypeException(message: String) extends Exception(message) + +class SqlLexical extends scala.util.parsing.combinator.lexical.StdLexical { + case class DecimalLit(chars: String) extends Token { + override def toString: String = chars + } + + /* This is a work around to support the lazy setting */ + def initialize(keywords: Seq[String]): Unit = { + reserved.clear() + reserved ++= keywords + } + + /* Normal the keyword string */ + def normalizeKeyword(str: String): String = str.toLowerCase + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" + ) + + protected override def processIdent(name: String) = { + val token = normalizeKeyword(name) + if (reserved contains token) Keyword(token) else Identifier(name) + } + + override lazy val token: Parser[Token] = + ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } + | '.' ~> (rep1(digit) ~ scientificNotation) ^^ + { case i ~ s => DecimalLit("0." + i.mkString + s) } + | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ + { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." 
+ i2.mkString + s) } + | digit.* ~ identChar ~ (identChar | digit).* ^^ + { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } + | rep1(digit) ~ ('.' ~> digit.*).? ^^ { + case i ~ None => NumericLit(i.mkString) + case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) + } + | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ + { case chars => StringLit(chars mkString "") } + | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ + { case chars => StringLit(chars mkString "") } + | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^ + { case chars => Identifier(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar: Parser[Elem] = letter | elem('_') + + private lazy val scientificNotation: Parser[String] = + (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { + case s ~ rest => "e" + s.mkString + rest.mkString + } + + override def whitespace: Parser[Any] = + ( whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ chrExcept(EofCh, '\n').* + | '#' ~ chrExcept(EofCh, '\n').* + | '-' ~ '-' ~ chrExcept(EofCh, '\n').* + | '/' ~ '*' ~ failure("unclosed comment") + ).* +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 51cfc50130..d0132529f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -16,90 +16,105 @@ */ package org.apache.spark.sql.catalyst.parser -import scala.annotation.tailrec - -import org.antlr.runtime._ -import org.antlr.runtime.tree.CommonTree +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType /** - * The ParseDriver takes a SQL command and turns this into an AST. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver + * Base SQL parsing infrastructure. */ -object ParseDriver extends Logging { - /** Create an LogicalPlan ASTNode from a SQL command. */ - def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.statement().getTree - } +abstract class AbstractSqlParser extends ParserInterface with Logging { - /** Create an Expression ASTNode from a SQL command. */ - def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.singleNamedExpression().getTree + /** Creates/Resolves DataType for a given SQL string. */ + def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => + // TODO add this to the parser interface. + astBuilder.visitSingleDataType(parser.singleDataType()) } - /** Create an TableIdentifier ASTNode from a SQL command. */ - def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.singleTableName().getTree + /** Creates Expression for a given SQL string. 
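   * For example, parseExpression("a + 1") yields (roughly)
   * Add(UnresolvedAttribute("a"), Literal(1)).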
*/ + override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => + astBuilder.visitSingleExpression(parser.singleExpression()) } - private def parse( - command: String, - conf: ParserConf)( - toTree: SparkSqlParser => CommonTree): ASTNode = { - logInfo(s"Parsing command: $command") + /** Creates TableIdentifier for a given SQL string. */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) + } - // Setup error collection. - val reporter = new ParseErrorReporter() + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visitSingleStatement(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => nativeCommand(sqlText) + } + } - // Create lexer. - val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command)) - val tokens = new TokenRewriteStream(lexer) - lexer.configure(conf, reporter) + /** Get the builder (visitor) which converts a ParseTree into a AST. */ + protected def astBuilder: AstBuilder - // Create the parser. - val parser = new SparkSqlParser(tokens) - parser.configure(conf, reporter) + /** Create a native command, or fail when this is not supported. */ + protected def nativeCommand(sqlText: String): LogicalPlan = { + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) + } - try { - val result = toTree(parser) + protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + logInfo(s"Parsing command: $command") - // Check errors. - reporter.checkForErrors() + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) - // Return the AST node from the result. - logInfo(s"Parse completed.") + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) - // Find the non null token tree in the result. - @tailrec - def nonNullToken(tree: CommonTree): CommonTree = { - if (tree.token != null || tree.getChildCount == 0) tree - else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree]) + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) } - val tree = nonNullToken(result) - - // Make sure all boundaries are set. - tree.setUnknownTokenBoundaries() - - // Construct the immutable AST. - def createASTNode(tree: CommonTree): ASTNode = { - val children = (0 until tree.getChildCount).map { i => - createASTNode(tree.getChild(i).asInstanceOf[CommonTree]) - }.toList - ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens) + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.reset() // rewind input stream + parser.reset() + + // Try Again. 
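        // The full LL prediction mode is slower but strictly more general than SLL, so any
        // input that SLL rejected spuriously will parse now; a failure here is a real
        // syntax error.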
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) } - createASTNode(tree) } catch { - case e: RecognitionException => - logInfo(s"Parse failed.") - reporter.throwError(e) + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) } } } +/** + * Concrete SQL parser for Catalyst-only SQL statements. + */ +object CatalystSqlParser extends AbstractSqlParser { + val astBuilder = new AstBuilder +} + /** * This string stream provides the lexer with upper case characters only. This greatly simplifies * lexing the stream, while we can maintain the original command. @@ -120,58 +135,104 @@ object ParseDriver extends Logging { * have the ANTLRNoCaseStringStream implementation. */ -private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) { +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { override def LA(i: Int): Int = { val la = super.LA(i) - if (la == 0 || la == CharStream.EOF) la + if (la == 0 || la == IntStream.EOF) la else Character.toUpperCase(la) } } /** - * Utility used by the Parser and the Lexer for error collection and reporting. + * The ParseErrorListener converts parse errors into AnalysisExceptions. */ -private[parser] class ParseErrorReporter { - val errors = scala.collection.mutable.Buffer.empty[ParseError] - - def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = { - errors += ParseError(br, re, tokenNames) +case object ParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) } +} - def checkForErrors(): Unit = { - if (errors.nonEmpty) { - val first = errors.head - val e = first.re - throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail) - } +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. 
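 * When the original command and error position are known, getMessage reproduces the SQL text
 * under an "== SQL ==" header and marks the failing position with a "---^^^" pointer.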
+ */ +class ParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(ParserUtils.command(ctx)), + message, + ParserUtils.position(ctx.getStart), + ParserUtils.position(ctx.getStop)) } - def throwError(e: RecognitionException): Nothing = { - throwError(e.line, e.charPositionInLine, e.toString, errors) + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString } - private def throwError( - line: Int, - startPosition: Int, - msg: String, - errors: Seq[ParseError]): Nothing = { - val b = new StringBuilder - b.append(msg).append("\n") - errors.foreach(error => error.buildMessage(b).append("\n")) - throw new AnalysisException(b.toString, Option(line), Option(startPosition)) + def withCommand(cmd: String): ParseException = { + new ParseException(Option(cmd), message, start, stop) } } /** - * Error collected during the parsing process. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError + * The post-processor validates & cleans-up the parse tree during the parse process. */ -private[parser] case class ParseError( - br: BaseRecognizer, - re: RecognitionException, - tokenNames: Array[String]) { - def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = { - s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames)) +case object PostProcessor extends SqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + parent.addChild(f(new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + SqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala deleted file mode 100644 index ce449b1143..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser - -trait ParserConf { - def supportQuotedId: Boolean - def supportSQL11ReservedKeywords: Boolean -} - -case class SimpleParserConf( - supportQuotedId: Boolean = true, - supportSQL11ReservedKeywords: Boolean = false) extends ParserConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 0c2e481954..90b76dc314 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -14,166 +14,105 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.types._ +import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.TerminalNode +import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} /** - * A collection of utility methods and patterns for parsing query texts. + * A collection of utility methods for use during the parsing process. */ -// TODO: merge with ParseUtils object ParserUtils { - - object Token { - // Match on (text, children) - def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { - CurrentOrigin.setPosition(node.line, node.positionInLine) - node.pattern - } + /** Get the command which created the token. */ + def command(ctx: ParserRuleContext): String = { + command(ctx.getStart.getInputStream) } - private val escapedIdentifier = "`(.+)`".r - private val doubleQuotedString = "\"([^\"]+)\"".r - private val singleQuotedString = "'([^']+)'".r - - // Token patterns - val COUNT = "(?i)COUNT".r - val SUM = "(?i)SUM".r - val AND = "(?i)AND".r - val OR = "(?i)OR".r - val NOT = "(?i)NOT".r - val TRUE = "(?i)TRUE".r - val FALSE = "(?i)FALSE".r - val LIKE = "(?i)LIKE".r - val RLIKE = "(?i)RLIKE".r - val REGEXP = "(?i)REGEXP".r - val IN = "(?i)IN".r - val DIV = "(?i)DIV".r - val BETWEEN = "(?i)BETWEEN".r - val WHEN = "(?i)WHEN".r - val CASE = "(?i)CASE".r - val INTEGRAL = "[+-]?\\d+".r - val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r - - /** - * Strip quotes, if any, from the string. - */ - def unquoteString(str: String): String = { - str match { - case singleQuotedString(s) => s - case doubleQuotedString(s) => s - case other => other - } + /** Get the command which created the token. 
*/ + def command(stream: CharStream): String = { + stream.getText(Interval.of(0, stream.size())) } - /** - * Strip backticks, if any, from the string. - */ - def cleanIdentifier(ident: String): String = { - ident match { - case escapedIdentifier(i) => i - case plainIdent => plainIdent - } + /** Get the code that creates the given node. */ + def source(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) } - def getClauses( - clauseNames: Seq[String], - nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { - var remainingNodes = nodeList - val clauses = clauseNames.map { clauseName => - val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName) - remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) - matches.headOption - } + /** Get all the text which comes after the given rule. */ + def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) - if (remainingNodes.nonEmpty) { - sys.error( - s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}. - |You are likely trying to use an unsupported Hive feature."""".stripMargin) - } - clauses + /** Get all the text which comes after the given token. */ + def remainder(token: Token): String = { + val stream = token.getInputStream + val interval = Interval.of(token.getStopIndex + 1, stream.size()) + stream.getText(interval) } - def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = { - getClauseOption(clauseName, nodeList).getOrElse(sys.error( - s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}")) - } + /** Convert a string token into a string. */ + def string(token: Token): String = unescapeSQLString(token.getText) - def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = { - nodeList.filter { case ast: ASTNode => ast.text == clauseName } match { - case Seq(oneMatch) => Some(oneMatch) - case Seq() => None - case _ => sys.error(s"Found multiple instances of clause $clauseName") - } - } + /** Convert a string node into a string. */ + def string(node: TerminalNode): String = unescapeSQLString(node.getText) - def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = { - tableNameParts.children.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => TableIdentifier(tableOnly) - case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } + /** Get the origin (line and position) of the token. 
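The three text helpers above slice the same underlying character stream: `command` is the whole input, `source` is the span covered by a rule, and `remainder` is everything after it. A pure-Scala analogue of that relationship, using a plain string and hypothetical indices in place of an ANTLR `CharStream`:

```scala
// Hypothetical rule span [startIndex, stopIndex] over the full command text.
val command = "SELECT a FROM tbl WHERE a > 1"
val (startIndex, stopIndex) = (0, 16)                 // covers "SELECT a FROM tbl"

val source    = command.substring(startIndex, stopIndex + 1)
val remainder = command.substring(stopIndex + 1)

assert(source == "SELECT a FROM tbl")
assert(source + remainder == command)
```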
*/ + def position(token: Token): Origin = { + Origin(Option(token.getLine), Option(token.getCharPositionInLine)) } - def nodeToDataType(node: ASTNode): DataType = node match { - case Token("TOK_DECIMAL", precision :: scale :: Nil) => - DecimalType(precision.text.toInt, scale.text.toInt) - case Token("TOK_DECIMAL", precision :: Nil) => - DecimalType(precision.text.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT - case Token("TOK_BIGINT", Nil) => LongType - case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => ByteType - case Token("TOK_SMALLINT", Nil) => ShortType - case Token("TOK_BOOLEAN", Nil) => BooleanType - case Token("TOK_STRING", Nil) => StringType - case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => DoubleType - case Token("TOK_DATE", Nil) => DateType - case Token("TOK_TIMESTAMP", Nil) => TimestampType - case Token("TOK_BINARY", Nil) => BinaryType - case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) - case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) => - StructType(fields.map(nodeToStructField)) - case Token("TOK_MAP", keyType :: valueType :: Nil) => - MapType(nodeToDataType(keyType), nodeToDataType(valueType)) - case _ => - noParseRule("DataType", node) - } - - def nodeToStructField(node: ASTNode): StructField = node match { - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) => - StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) => - val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build() - StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta) - case _ => - noParseRule("StructField", node) + /** Assert if a condition holds. If it doesn't throw a parse exception. */ + def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + if (!f) { + throw new ParseException(message, ctx) + } } /** - * Throw an exception because we cannot parse the given node for some unexpected reason. + * Register the origin of the context. Any TreeNode created in the closure will be assigned the + * registered origin. This method restores the previously set origin after completion of the + * closure. */ - def parseFailed(msg: String, node: ASTNode): Nothing = { - throw new AnalysisException(s"$msg: '${node.source}") + def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } } - /** - * Throw an exception because there are no rules to parse the node. - */ - def noParseRule(msg: String, node: ASTNode): Nothing = { - throw new NotImplementedError( - s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}") - } + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ + implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { + /** + * Create a plan using the block of code when the given context exists. Otherwise return the + * original plan. 
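A pure-Scala analogue of the `optional` pattern described here, with a toy `Plan` type standing in for `LogicalPlan` and a nullable value standing in for the ANTLR context (both hypothetical):

```scala
final case class Plan(steps: List[String]) {
  // Apply the block only when the (possibly null) context is present.
  def optional(ctx: AnyRef)(f: => Plan): Plan = if (ctx != null) f else this
}

val base = Plan(List("scan"))
// No LIMIT clause in the parse tree: the plan is returned unchanged.
assert(base.optional(null)(Plan("limit" :: base.steps)) == base)
// A LIMIT clause is present: the block wraps the plan.
assert(base.optional("LIMIT 10")(Plan("limit" :: base.steps)).steps == List("limit", "scan"))
```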
+ */ + def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f + } else { + plan + } + } + /** + * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the + * passed function. The original plan is returned when the context does not exist. + */ + def optionalMap[C <: ParserRuleContext]( + ctx: C)( + f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f(ctx, plan) + } else { + plan + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala deleted file mode 100644 index 5a64c414fb..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala +++ /dev/null @@ -1,1452 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import java.sql.{Date, Timestamp} - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.antlr.v4.runtime.{ParserRuleContext, Token} -import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler - -/** - * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or - * TableIdentifier. 
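In practice the builder is driven through an `AbstractSqlParser` rather than used directly. A hedged REPL-style sketch, assuming the `CatalystSqlParser` entry point:

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// A SQL fragment becomes an unresolved catalyst tree: here an Add over an
// UnresolvedAttribute and a Literal (exact toString output may differ).
val expr = CatalystSqlParser.parseExpression("a + 1")
println(expr.getClass.getSimpleName)   // expected: Add
```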
- */ -class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { - import ParserUtils._ - - protected def typedVisit[T](ctx: ParseTree): T = { - ctx.accept(this).asInstanceOf[T] - } - - override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { - visit(ctx.statement).asInstanceOf[LogicalPlan] - } - - override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { - visitNamedExpression(ctx.namedExpression) - } - - override def visitSingleTableIdentifier( - ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { - visitTableIdentifier(ctx.tableIdentifier) - } - - override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { - visit(ctx.dataType).asInstanceOf[DataType] - } - - /* ******************************************************************************************** - * Plan parsing - * ******************************************************************************************** */ - protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) - - /** - * Make sure we do not try to create a plan for a native command. - */ - override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null - - /** - * Create a plan for a SHOW FUNCTIONS command. - */ - override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { - import ctx._ - if (qualifiedName != null) { - val names = qualifiedName().identifier().asScala.map(_.getText).toList - names match { - case db :: name :: Nil => - ShowFunctions(Some(db), Some(name)) - case name :: Nil => - ShowFunctions(None, Some(name)) - case _ => - throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) - } - } else if (pattern != null) { - ShowFunctions(None, Some(string(pattern))) - } else { - ShowFunctions(None, None) - } - } - - /** - * Create a plan for a DESCRIBE FUNCTION command. - */ - override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { - val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".") - DescribeFunction(functionName, ctx.EXTENDED != null) - } - - /** - * Create a top-level plan with Common Table Expressions. - */ - override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { - val query = plan(ctx.queryNoWith) - - // Apply CTEs - query.optional(ctx.ctes) { - val ctes = ctx.ctes.namedQuery.asScala.map { - case nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) - } - - // Check for duplicate names. - ctes.groupBy(_._1).filter(_._2.size > 1).foreach { - case (name, _) => - throw new ParseException( - s"Name '$name' is used for multiple common table expressions", ctx) - } - - With(query, ctes.toMap) - } - } - - /** - * Create a named logical plan. - * - * This is only used for Common Table Expressions. - */ - override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { - SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith)) - } - - /** - * Create a logical plan which allows for multiple inserts using one 'from' statement. These - * queries have the following SQL form: - * {{{ - * [WITH cte...]? - * FROM src - * [INSERT INTO tbl1 SELECT *]+ - * }}} - * For example: - * {{{ - * FROM db.tbl1 A - * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 - * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 - * }}} - * This (Hive) feature cannot be combined with set-operators. 
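The duplicate-name check in `visitQuery` above boils down to grouping the CTE aliases and rejecting any alias that occurs more than once. A pure-Scala sketch with hypothetical alias names:

```scala
val ctes = Seq("t1" -> "SELECT 1", "t2" -> "SELECT 2", "t1" -> "SELECT 3")

// Group by alias and keep only the names used more than once; the builder
// raises a ParseException when this set is non-empty.
val duplicates = ctes.groupBy { case (alias, _) => alias }.filter(_._2.size > 1).keySet
assert(duplicates == Set("t1"))
```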
- */ - override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { - val from = visitFromClause(ctx.fromClause) - - // Build the insert clauses. - val inserts = ctx.multiInsertQueryBody.asScala.map { - body => - assert(body.querySpecification.fromClause == null, - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", - body) - - withQuerySpecification(body.querySpecification, from). - // Add organization statements. - optionalMap(body.queryOrganization)(withQueryResultClauses). - // Add insert. - optionalMap(body.insertInto())(withInsertInto) - } - - // If there are multiple INSERTS just UNION them together into one query. - inserts match { - case Seq(query) => query - case queries => Union(queries) - } - } - - /** - * Create a logical plan for a regular (single-insert) query. - */ - override def visitSingleInsertQuery( - ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryTerm). - // Add organization statements. - optionalMap(ctx.queryOrganization)(withQueryResultClauses). - // Add insert. - optionalMap(ctx.insertInto())(withInsertInto) - } - - /** - * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan. - */ - private def withInsertInto( - ctx: InsertIntoContext, - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - val tableIdent = visitTableIdentifier(ctx.tableIdentifier) - val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) - - InsertIntoTable( - UnresolvedRelation(tableIdent, None), - partitionKeys, - query, - ctx.OVERWRITE != null, - ctx.EXISTS != null) - } - - /** - * Create a partition specification map. - */ - override def visitPartitionSpec( - ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { - ctx.partitionVal.asScala.map { pVal => - val name = pVal.identifier.getText.toLowerCase - val value = Option(pVal.constant).map(visitStringConstant) - name -> value - }.toMap - } - - /** - * Create a partition specification map without optional values. - */ - protected def visitNonOptionalPartitionSpec( - ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { - visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) - } - - /** - * Convert a constant of any type into a string. This is typically used in DDL commands, and its - * main purpose is to prevent slight differences due to back to back conversions i.e.: - * String -> Literal -> String. - */ - protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { - ctx match { - case s: StringLiteralContext => createString(s) - case o => o.getText - } - } - - /** - * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These - * clauses determine the shape (ordering/partitioning/rows) of the query result. - */ - private def withQueryResultClauses( - ctx: QueryOrganizationContext, - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - import ctx._ - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. - val withOrder = if ( - !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { - // ORDER BY ... - Sort(order.asScala.map(visitSortItem), global = true, query) - } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { - // SORT BY ... 
- Sort(sort.asScala.map(visitSortItem), global = false, query) - } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { - // DISTRIBUTE BY ... - RepartitionByExpression(expressionList(distributeBy), query) - } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { - // SORT BY ... DISTRIBUTE BY ... - Sort( - sort.asScala.map(visitSortItem), - global = false, - RepartitionByExpression(expressionList(distributeBy), query)) - } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { - // CLUSTER BY ... - val expressions = expressionList(clusterBy) - Sort( - expressions.map(SortOrder(_, Ascending)), - global = false, - RepartitionByExpression(expressions, query)) - } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { - // [EMPTY] - query - } else { - throw new ParseException( - "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) - } - - // WINDOWS - val withWindow = withOrder.optionalMap(windows)(withWindows) - - // LIMIT - withWindow.optional(limit) { - Limit(typedVisit(limit), withWindow) - } - } - - /** - * Create a logical plan using a query specification. - */ - override def visitQuerySpecification( - ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { - val from = OneRowRelation.optional(ctx.fromClause) { - visitFromClause(ctx.fromClause) - } - withQuerySpecification(ctx, from) - } - - /** - * Add a query specification to a logical plan. The query specification is the core of the logical - * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), - * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. - * - * Note that query hints are ignored (both by the parser and the builder). - */ - private def withQuerySpecification( - ctx: QuerySpecificationContext, - relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { - import ctx._ - - // WHERE - def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { - Filter(expression(ctx), plan) - } - - // Expressions. - val expressions = Option(namedExpressionSeq).toSeq - .flatMap(_.namedExpression.asScala) - .map(typedVisit[Expression]) - - // Create either a transform or a regular query. - val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT) - specType match { - case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM => - // Transform - - // Add where. - val withFilter = relation.optionalMap(where)(filter) - - // Create the attributes. - val (attributes, schemaLess) = if (colTypeList != null) { - // Typed return columns. - (createStructType(colTypeList).toAttributes, false) - } else if (identifierSeq != null) { - // Untyped return columns. - val attrs = visitIdentifierSeq(identifierSeq).map { name => - AttributeReference(name, StringType, nullable = true)() - } - (attrs, false) - } else { - (Seq(AttributeReference("key", StringType)(), - AttributeReference("value", StringType)()), true) - } - - // Create the transform. - ScriptTransformation( - expressions, - string(script), - attributes, - withFilter, - withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) - - case SqlBaseParser.SELECT => - // Regular select - - // Add lateral views. - val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) - - // Add where. 
- val withFilter = withLateralView.optionalMap(where)(filter) - - // Add aggregation or a project. - val namedExpressions = expressions.map { - case e: NamedExpression => e - case e: Expression => UnresolvedAlias(e) - } - val withProject = if (aggregation != null) { - withAggregation(aggregation, namedExpressions, withFilter) - } else if (namedExpressions.nonEmpty) { - Project(namedExpressions, withFilter) - } else { - withFilter - } - - // Having - val withHaving = withProject.optional(having) { - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(expression(having), BooleanType), withProject) - } - - // Distinct - val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { - Distinct(withHaving) - } else { - withHaving - } - - // Window - withDistinct.optionalMap(windows)(withWindows) - } - } - - /** - * Create a (Hive based) [[ScriptInputOutputSchema]]. - */ - protected def withScriptIOSchema( - inRowFormat: RowFormatContext, - recordWriter: Token, - outRowFormat: RowFormatContext, - recordReader: Token, - schemaLess: Boolean): ScriptInputOutputSchema = null - - /** - * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma - * separated) relations here, these get converted into a single plan by condition-less inner join. - */ - override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { - val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) - ctx.lateralView.asScala.foldLeft(from)(withGenerate) - } - - /** - * Connect two queries by a Set operator. - * - * Supported Set operators are: - * - UNION [DISTINCT] - * - UNION ALL - * - EXCEPT [DISTINCT] - * - INTERSECT [DISTINCT] - */ - override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { - val left = plan(ctx.left) - val right = plan(ctx.right) - val all = Option(ctx.setQuantifier()).exists(_.ALL != null) - ctx.operator.getType match { - case SqlBaseParser.UNION if all => - Union(left, right) - case SqlBaseParser.UNION => - Distinct(Union(left, right)) - case SqlBaseParser.INTERSECT if all => - throw new ParseException("INTERSECT ALL is not supported.", ctx) - case SqlBaseParser.INTERSECT => - Intersect(left, right) - case SqlBaseParser.EXCEPT if all => - throw new ParseException("EXCEPT ALL is not supported.", ctx) - case SqlBaseParser.EXCEPT => - Except(left, right) - } - } - - /** - * Add a [[WithWindowDefinition]] operator to a logical plan. - */ - private def withWindows( - ctx: WindowsContext, - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - // Collect all window specifications defined in the WINDOW clause. - val baseWindowMap = ctx.namedWindow.asScala.map { - wCtx => - (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) - }.toMap - - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val windowMapView = baseWindowMap.mapValues { - case WindowSpecReference(name) => - baseWindowMap.get(name) match { - case Some(spec: WindowSpecDefinition) => - spec - case Some(ref) => - throw new ParseException(s"Window reference '$name' is not a window specification", ctx) - case None => - throw new ParseException(s"Cannot resolve window reference '$name'", ctx) - } - case spec: WindowSpecDefinition => spec - } - - // Note that mapValues creates a view instead of materialized map. 
We force materialization by - // mapping over identity. - WithWindowDefinition(windowMapView.map(identity), query) - } - - /** - * Add an [[Aggregate]] to a logical plan. - */ - private def withAggregation( - ctx: AggregationContext, - selectExpressions: Seq[NamedExpression], - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - import ctx._ - val groupByExpressions = expressionList(groupingExpressions) - - if (GROUPING != null) { - // GROUP BY .... GROUPING SETS (...) - val expressionMap = groupByExpressions.zipWithIndex.toMap - val numExpressions = expressionMap.size - val mask = (1 << numExpressions) - 1 - val masks = ctx.groupingSet.asScala.map { - _.expression.asScala.foldLeft(mask) { - case (bitmap, eCtx) => - // Find the index of the expression. - val e = typedVisit[Expression](eCtx) - val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( - throw new ParseException( - s"$e doesn't show up in the GROUP BY list", ctx)) - // 0 means that the column at the given index is a grouping column, 1 means it is not, - // so we unset the bit in bitmap. - bitmap & ~(1 << (numExpressions - 1 - index)) - } - } - GroupingSets(masks, groupByExpressions, query, selectExpressions) - } else { - // GROUP BY .... (WITH CUBE | WITH ROLLUP)? - val mappedGroupByExpressions = if (CUBE != null) { - Seq(Cube(groupByExpressions)) - } else if (ROLLUP != null) { - Seq(Rollup(groupByExpressions)) - } else { - groupByExpressions - } - Aggregate(mappedGroupByExpressions, selectExpressions, query) - } - } - - /** - * Add a [[Generate]] (Lateral View) to a logical plan. - */ - private def withGenerate( - query: LogicalPlan, - ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { - val expressions = expressionList(ctx.expression) - - // Create the generator. - val generator = ctx.qualifiedName.getText.toLowerCase match { - case "explode" if expressions.size == 1 => - Explode(expressions.head) - case "json_tuple" => - JsonTuple(expressions) - case other => - withGenerator(other, expressions, ctx) - } - - Generate( - generator, - join = true, - outer = ctx.OUTER != null, - Some(ctx.tblName.getText.toLowerCase), - ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), - query) - } - - /** - * Create a [[Generator]]. Override this method in order to support custom Generators. - */ - protected def withGenerator( - name: String, - expressions: Seq[Expression], - ctx: LateralViewContext): Generator = { - throw new ParseException(s"Generator function '$name' is not supported", ctx) - } - - /** - * Create a joins between two or more logical plans. - */ - override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { - /** Build a join between two plans. 
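Returning to the GROUPING SETS handling above: each grouping set is encoded as a bitmask over the GROUP BY expressions, where a cleared bit marks a grouping column. A pure-Scala sketch with hypothetical column names:

```scala
val groupBy  = Seq("a", "b", "c")
val indexOf  = groupBy.zipWithIndex.toMap
val fullMask = (1 << groupBy.size) - 1                 // 0b111 for three expressions

// Clear the bit of every column referenced by the grouping set.
def maskOf(set: Seq[String]): Int =
  set.foldLeft(fullMask) { (bits, col) =>
    bits & ~(1 << (groupBy.size - 1 - indexOf(col)))
  }

// For the set {a, c}: a and c are grouping columns (0), b is not (1).
assert(maskOf(Seq("a", "c")) == Integer.parseInt("010", 2))
```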
*/ - def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { - val baseJoinType = ctx.joinType match { - case null => Inner - case jt if jt.FULL != null => FullOuter - case jt if jt.SEMI != null => LeftSemi - case jt if jt.LEFT != null => LeftOuter - case jt if jt.RIGHT != null => RightOuter - case _ => Inner - } - - // Resolve the join type and join condition - val (joinType, condition) = Option(ctx.joinCriteria) match { - case Some(c) if c.USING != null => - val columns = c.identifier.asScala.map { column => - UnresolvedAttribute.quoted(column.getText) - } - (UsingJoin(baseJoinType, columns), None) - case Some(c) if c.booleanExpression != null => - (baseJoinType, Option(expression(c.booleanExpression))) - case None if ctx.NATURAL != null => - (NaturalJoin(baseJoinType), None) - case None => - (baseJoinType, None) - } - Join(left, right, joinType, condition) - } - - // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the - // first join clause is at the top. However fields of previously referenced tables can be used - // in following join clauses. The tree needs to be reversed in order to make this work. - var result = plan(ctx.left) - var current = ctx - while (current != null) { - current.right match { - case right: JoinRelationContext => - result = join(current, result, plan(right.left)) - current = right - case right => - result = join(current, result, plan(right)) - current = null - } - } - result - } - - /** - * Add a [[Sample]] to a logical plan. - * - * This currently supports the following sampling methods: - * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. - * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages - * are defined as a number between 0 and 100. - * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. - */ - private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - // Create a sampled plan if we need one. - def sample(fraction: Double): Sample = { - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. - val eps = RandomSampler.roundingEpsilon - assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, - s"Sampling fraction ($fraction) must be on interval [0, 1]", - ctx) - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) - } - - ctx.sampleType.getType match { - case SqlBaseParser.ROWS => - Limit(expression(ctx.expression), query) - - case SqlBaseParser.PERCENTLIT => - val fraction = ctx.percentage.getText.toDouble - sample(fraction / 100.0d) - - case SqlBaseParser.BUCKET if ctx.ON != null => - throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) - - case SqlBaseParser.BUCKET => - sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) - } - } - - /** - * Create a logical plan for a sub-query. - */ - override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryNoWith) - } - - /** - * Create an un-aliased table reference. 
This is typically used for top-level table references, - * for example: - * {{{ - * INSERT INTO db.tbl2 - * TABLE db.tbl1 - * }}} - */ - override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { - UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) - } - - /** - * Create an aliased table reference. This is typically used in FROM clauses. - */ - override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { - val table = UnresolvedRelation( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.identifier).map(_.getText)) - table.optionalMap(ctx.sample)(withSample) - } - - /** - * Create an inline table (a virtual table in Hive parlance). - */ - override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { - // Get the backing expressions. - val expressions = ctx.expression.asScala.map { eCtx => - val e = expression(eCtx) - assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) - e - } - - // Validate and evaluate the rows. - val (structType, structConstructor) = expressions.head.dataType match { - case st: StructType => - (st, (e: Expression) => e) - case dt => - val st = CreateStruct(Seq(expressions.head)).dataType - (st, (e: Expression) => CreateStruct(Seq(e))) - } - val rows = expressions.map { - case expression => - val safe = Cast(structConstructor(expression), structType) - safe.eval().asInstanceOf[InternalRow] - } - - // Construct attributes. - val baseAttributes = structType.toAttributes.map(_.withNullability(true)) - val attributes = if (ctx.identifierList != null) { - val aliases = visitIdentifierList(ctx.identifierList) - assert(aliases.size == baseAttributes.size, - "Number of aliases must match the number of fields in an inline table.", ctx) - baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) - } else { - baseAttributes - } - - // Create plan and add an alias if a name has been defined. - LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) - } - - /** - * Create an alias (SubqueryAlias) for a join relation. This is practically the same as - * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different - * hooks. - */ - override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) - } - - /** - * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as - * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different - * hooks. - */ - override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) - } - - /** - * Create an alias (SubqueryAlias) for a LogicalPlan. - */ - private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { - SubqueryAlias(alias.getText, plan) - } - - /** - * Create a Sequence of Strings for a parenthesis enclosed alias list. - */ - override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { - visitIdentifierSeq(ctx.identifierSeq) - } - - /** - * Create a Sequence of Strings for an identifier list. 
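Looking back at the join handling above: ANTLR yields consecutive join clauses right-nested, but the logical plan must grow left-to-right so that earlier relations stay in scope for later join conditions. A pure-Scala sketch of that left fold, with a hypothetical relation ADT:

```scala
sealed trait Rel
final case class Table(name: String) extends Rel
final case class JoinRel(left: Rel, right: Rel) extends Rel

// "FROM a JOIN b JOIN c" should end up as ((a JOIN b) JOIN c).
val relations: List[Rel] = List(Table("a"), Table("b"), Table("c"))
val plan = relations.reduceLeft[Rel](JoinRel(_, _))

assert(plan == JoinRel(JoinRel(Table("a"), Table("b")), Table("c")))
```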
- */ - override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { - ctx.identifier.asScala.map(_.getText) - } - - /* ******************************************************************************************** - * Table Identifier parsing - * ******************************************************************************************** */ - /** - * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. - */ - override def visitTableIdentifier( - ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { - TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) - } - - /* ******************************************************************************************** - * Expression parsing - * ******************************************************************************************** */ - /** - * Create an expression from the given context. This method just passes the context on to the - * vistor and only takes care of typing (We assume that the visitor returns an Expression here). - */ - protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) - - /** - * Create sequence of expressions from the given sequence of contexts. - */ - private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { - trees.asScala.map(expression) - } - - /** - * Invert a boolean expression if it has a valid NOT clause. - */ - private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { - if (not != null) { - Not(expression) - } else { - expression - } - } - - /** - * Create a star (i.e. all) expression; this selects all elements (in the specified object). - * Both un-targeted (global) and targeted aliases are supported. - */ - override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { - UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) - } - - /** - * Create an aliased expression if an alias is specified. Both single and multi-aliases are - * supported. - */ - override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { - val e = expression(ctx.expression) - if (ctx.identifier != null) { - Alias(e, ctx.identifier.getText)() - } else if (ctx.identifierList != null) { - MultiAlias(e, visitIdentifierList(ctx.identifierList)) - } else { - e - } - } - - /** - * Combine a number of boolean expressions into a balanced expression tree. These expressions are - * either combined by a logical [[And]] or a logical [[Or]]. - * - * A balanced binary tree is created because regular left recursive trees cause considerable - * performance degradations and can cause stack overflows. - */ - override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { - val expressionType = ctx.operator.getType - val expressionCombiner = expressionType match { - case SqlBaseParser.AND => And.apply _ - case SqlBaseParser.OR => Or.apply _ - } - - // Collect all similar left hand contexts. - val contexts = ArrayBuffer(ctx.right) - var current = ctx.left - def collectContexts: Boolean = current match { - case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => - contexts += lbc.right - current = lbc.left - true - case _ => - contexts += current - false - } - while (collectContexts) { - // No body - all updates take place in the collectContexts. - } - - // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them - // into expressions. 
- val expressions = contexts.reverse.map(expression) - - // Create a balanced tree. - def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { - case 0 => - expressions(low) - case 1 => - expressionCombiner(expressions(low), expressions(high)) - case x => - val mid = low + x / 2 - expressionCombiner( - reduceToExpressionTree(low, mid), - reduceToExpressionTree(mid + 1, high)) - } - reduceToExpressionTree(0, expressions.size - 1) - } - - /** - * Invert a boolean expression. - */ - override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { - Not(expression(ctx.booleanExpression())) - } - - /** - * Create a filtering correlated sub-query. This is not supported yet. - */ - override def visitExists(ctx: ExistsContext): Expression = { - throw new ParseException("EXISTS clauses are not supported.", ctx) - } - - /** - * Create a comparison expression. This compares two expressions. The following comparison - * operators are supported: - * - Equal: '=' or '==' - * - Null-safe Equal: '<=>' - * - Not Equal: '<>' or '!=' - * - Less than: '<' - * - Less then or Equal: '<=' - * - Greater than: '>' - * - Greater then or Equal: '>=' - */ - override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { - val left = expression(ctx.left) - val right = expression(ctx.right) - val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] - operator.getSymbol.getType match { - case SqlBaseParser.EQ => - EqualTo(left, right) - case SqlBaseParser.NSEQ => - EqualNullSafe(left, right) - case SqlBaseParser.NEQ | SqlBaseParser.NEQJ => - Not(EqualTo(left, right)) - case SqlBaseParser.LT => - LessThan(left, right) - case SqlBaseParser.LTE => - LessThanOrEqual(left, right) - case SqlBaseParser.GT => - GreaterThan(left, right) - case SqlBaseParser.GTE => - GreaterThanOrEqual(left, right) - } - } - - /** - * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two - * other expressions. The inverse can also be created. - */ - override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - val between = And( - GreaterThanOrEqual(value, expression(ctx.lower)), - LessThanOrEqual(value, expression(ctx.upper))) - invertIfNotDefined(between, ctx.NOT) - } - - /** - * Create an IN expression. This tests if the value of the left hand side expression is - * contained by the sequence of expressions on the right hand side. - */ - override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { - val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) - invertIfNotDefined(in, ctx.NOT) - } - - /** - * Create an IN expression, where the the right hand side is a query. This is unsupported. - */ - override def visitInSubquery(ctx: InSubqueryContext): Expression = { - throw new ParseException("IN with a Sub-query is currently not supported.", ctx) - } - - /** - * Create a (R)LIKE/REGEXP expression. - */ - override def visitLike(ctx: LikeContext): Expression = { - val left = expression(ctx.value) - val right = expression(ctx.pattern) - val like = ctx.like.getType match { - case SqlBaseParser.LIKE => - Like(left, right) - case SqlBaseParser.RLIKE => - RLike(left, right) - } - invertIfNotDefined(like, ctx.NOT) - } - - /** - * Create an IS (NOT) NULL expression. 
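The balanced combination of AND/OR operands above is worth a standalone sketch: instead of folding into a deep left-leaning chain, the operands are split recursively so the tree depth stays logarithmic. A pure-Scala analogue with a hypothetical expression ADT:

```scala
sealed trait Expr
final case class Leaf(name: String) extends Expr
final case class And(left: Expr, right: Expr) extends Expr

def reduceToTree(xs: IndexedSeq[Expr], low: Int, high: Int): Expr = (high - low) match {
  case 0 => xs(low)
  case 1 => And(xs(low), xs(high))
  case n =>
    val mid = low + n / 2
    And(reduceToTree(xs, low, mid), reduceToTree(xs, mid + 1, high))
}

val predicates = (1 to 4).map(i => Leaf(s"p$i"): Expr)
// p1 AND p2 AND p3 AND p4  =>  (p1 AND p2) AND (p3 AND p4)
assert(reduceToTree(predicates, 0, predicates.size - 1) ==
  And(And(Leaf("p1"), Leaf("p2")), And(Leaf("p3"), Leaf("p4"))))
```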
- */ - override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - if (ctx.NOT != null) { - IsNotNull(value) - } else { - IsNull(value) - } - } - - /** - * Create a binary arithmetic expression. The following arithmetic operators are supported: - * - Mulitplication: '*' - * - Division: '/' - * - Hive Long Division: 'DIV' - * - Modulo: '%' - * - Addition: '+' - * - Subtraction: '-' - * - Binary AND: '&' - * - Binary XOR - * - Binary OR: '|' - */ - override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { - val left = expression(ctx.left) - val right = expression(ctx.right) - ctx.operator.getType match { - case SqlBaseParser.ASTERISK => - Multiply(left, right) - case SqlBaseParser.SLASH => - Divide(left, right) - case SqlBaseParser.PERCENT => - Remainder(left, right) - case SqlBaseParser.DIV => - Cast(Divide(left, right), LongType) - case SqlBaseParser.PLUS => - Add(left, right) - case SqlBaseParser.MINUS => - Subtract(left, right) - case SqlBaseParser.AMPERSAND => - BitwiseAnd(left, right) - case SqlBaseParser.HAT => - BitwiseXor(left, right) - case SqlBaseParser.PIPE => - BitwiseOr(left, right) - } - } - - /** - * Create a unary arithmetic expression. The following arithmetic operators are supported: - * - Plus: '+' - * - Minus: '-' - * - Bitwise Not: '~' - */ - override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { - val value = expression(ctx.valueExpression) - ctx.operator.getType match { - case SqlBaseParser.PLUS => - value - case SqlBaseParser.MINUS => - UnaryMinus(value) - case SqlBaseParser.TILDE => - BitwiseNot(value) - } - } - - /** - * Create a [[Cast]] expression. - */ - override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { - Cast(expression(ctx.expression), typedVisit(ctx.dataType)) - } - - /** - * Create a (windowed) Function expression. - */ - override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { - // Create the function call. - val name = ctx.qualifiedName.getText - val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.expression().asScala.map(expression) match { - case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => - // Transform COUNT(*) into COUNT(1). Move this to analysis? - Seq(Literal(1)) - case expressions => - expressions - } - val function = UnresolvedFunction(name, arguments, isDistinct) - - // Check if the function is evaluated in a windowed context. - ctx.windowSpec match { - case spec: WindowRefContext => - UnresolvedWindowExpression(function, visitWindowRef(spec)) - case spec: WindowDefContext => - WindowExpression(function, visitWindowDef(spec)) - case _ => function - } - } - - /** - * Create a reference to a window frame, i.e. [[WindowSpecReference]]. - */ - override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { - WindowSpecReference(ctx.identifier.getText) - } - - /** - * Create a window definition, i.e. [[WindowSpecDefinition]]. - */ - override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { - // CLUSTER BY ... | PARTITION BY ... ORDER BY ... - val partition = ctx.partition.asScala.map(expression) - val order = ctx.sortItem.asScala.map(visitSortItem) - - // RANGE/ROWS BETWEEN ... 
- val frameSpecOption = Option(ctx.windowFrame).map { frame => - val frameType = frame.frameType.getType match { - case SqlBaseParser.RANGE => RangeFrame - case SqlBaseParser.ROWS => RowFrame - } - - SpecifiedWindowFrame( - frameType, - visitFrameBound(frame.start), - Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) - } - - WindowSpecDefinition( - partition, - order, - frameSpecOption.getOrElse(UnspecifiedFrame)) - } - - /** - * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value - * Preceding/Following boundaries. These expressions must be constant (foldable) and return an - * integer value. - */ - override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { - // We currently only allow foldable integers. - def value: Int = { - val e = expression(ctx.expression) - assert(e.resolved && e.foldable && e.dataType == IntegerType, - "Frame bound value must be a constant integer.", - ctx) - e.eval().asInstanceOf[Int] - } - - // Create the FrameBoundary - ctx.boundType.getType match { - case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => - UnboundedPreceding - case SqlBaseParser.PRECEDING => - ValuePreceding(value) - case SqlBaseParser.CURRENT => - CurrentRow - case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => - UnboundedFollowing - case SqlBaseParser.FOLLOWING => - ValueFollowing(value) - } - } - - /** - * Create a [[CreateStruct]] expression. - */ - override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.expression.asScala.map(expression)) - } - - /** - * Create a [[ScalarSubquery]] expression. - */ - override def visitSubqueryExpression( - ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { - ScalarSubquery(plan(ctx.query)) - } - - /** - * Create a value based [[CaseWhen]] expression. This has the following SQL form: - * {{{ - * CASE [expression] - * WHEN [value] THEN [expression] - * ... - * ELSE [expression] - * END - * }}} - */ - override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { - val e = expression(ctx.valueExpression) - val branches = ctx.whenClause.asScala.map { wCtx => - (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) - } - CaseWhen(branches, Option(ctx.elseExpression).map(expression)) - } - - /** - * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: - * {{{ - * CASE - * WHEN [predicate] THEN [expression] - * ... - * ELSE [expression] - * END - * }}} - * - * @param ctx the parse tree - * */ - override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { - val branches = ctx.whenClause.asScala.map { wCtx => - (expression(wCtx.condition), expression(wCtx.result)) - } - CaseWhen(branches, Option(ctx.elseExpression).map(expression)) - } - - /** - * Create a dereference expression. The return type depends on the type of the parent, this can - * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an - * [[UnresolvedExtractValue]] if the parent is some expression. - */ - override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { - val attr = ctx.fieldName.getText - expression(ctx.base) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ attr) - case e => - UnresolvedExtractValue(e, Literal(attr)) - } - } - - /** - * Create an [[UnresolvedAttribute]] expression. 
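A small hedged check of the dereference rule described just above, again assuming the `CatalystSqlParser` entry point:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// Because the base of "a.b" is itself an UnresolvedAttribute, the field name is
// appended to its name parts rather than wrapped in an UnresolvedExtractValue.
val e = CatalystSqlParser.parseExpression("a.b")
println(e == UnresolvedAttribute(Seq("a", "b")))   // expected: true
```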
- */ - override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { - UnresolvedAttribute.quoted(ctx.getText) - } - - /** - * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. - */ - override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { - UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) - } - - /** - * Create an expression for an expression between parentheses. This is need because the ANTLR - * visitor cannot automatically convert the nested context into an expression. - */ - override def visitParenthesizedExpression( - ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { - expression(ctx.expression) - } - - /** - * Create a [[SortOrder]] expression. - */ - override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { - if (ctx.DESC != null) { - SortOrder(expression(ctx.expression), Descending) - } else { - SortOrder(expression(ctx.expression), Ascending) - } - } - - /** - * Create a typed Literal expression. A typed literal has the following SQL syntax: - * {{{ - * [TYPE] '[VALUE]' - * }}} - * Currently Date and Timestamp typed literals are supported. - * - * TODO what the added value of this over casting? - */ - override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { - val value = string(ctx.STRING) - ctx.identifier.getText.toUpperCase match { - case "DATE" => - Literal(Date.valueOf(value)) - case "TIMESTAMP" => - Literal(Timestamp.valueOf(value)) - case other => - throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) - } - } - - /** - * Create a NULL literal expression. - */ - override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { - Literal(null) - } - - /** - * Create a Boolean literal expression. - */ - override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { - if (ctx.getText.toBoolean) { - Literal.TrueLiteral - } else { - Literal.FalseLiteral - } - } - - /** - * Create an integral literal expression. The code selects the most narrow integral type - * possible, either a BigDecimal, a Long or an Integer is returned. - */ - override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { - BigDecimal(ctx.getText) match { - case v if v.isValidInt => - Literal(v.intValue()) - case v if v.isValidLong => - Literal(v.longValue()) - case v => Literal(v.underlying()) - } - } - - /** - * Create a double literal for a number denoted in scientifc notation. - */ - override def visitScientificDecimalLiteral( - ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { - Literal(ctx.getText.toDouble) - } - - /** - * Create a decimal literal for a regular decimal number. - */ - override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { - Literal(BigDecimal(ctx.getText).underlying()) - } - - /** Create a numeric literal expression. */ - private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { - val raw = ctx.getText - try { - Literal(f(raw.substring(0, raw.length - 1))) - } catch { - case e: NumberFormatException => - throw new ParseException(e.getMessage, ctx) - } - } - - /** - * Create a Byte Literal expression. - */ - override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { - _.toByte - } - - /** - * Create a Short Literal expression. 
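A pure-Scala sketch of the integral-literal narrowing above: the parser picks the narrowest type that can hold the value, falling back to an exact decimal.

```scala
// Mirror of the narrowing logic: Int if it fits, then Long, else java.math.BigDecimal.
def narrow(text: String): Any = BigDecimal(text) match {
  case v if v.isValidInt  => v.intValue()
  case v if v.isValidLong => v.longValue()
  case v                  => v.underlying()
}

assert(narrow("42") == 42)
assert(narrow("3000000000") == 3000000000L)
assert(narrow("1" + "0" * 30) == new java.math.BigDecimal("1" + "0" * 30))
```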
- */ - override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { - _.toShort - } - - /** - * Create a Long Literal expression. - */ - override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { - _.toLong - } - - /** - * Create a Double Literal expression. - */ - override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { - _.toDouble - } - - /** - * Create a String literal expression. - */ - override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal(createString(ctx)) - } - - /** - * Create a String from a string literal context. This supports multiple consecutive string - * literals, these are concatenated, for example this expression "'hello' 'world'" will be - * converted into "helloworld". - * - * Special characters can be escaped by using Hive/C-style escaping. - */ - private def createString(ctx: StringLiteralContext): String = { - ctx.STRING().asScala.map(string).mkString - } - - /** - * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple - * unit value pairs, for instance: interval 2 months 2 days. - */ - override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { - val intervals = ctx.intervalField.asScala.map(visitIntervalField) - assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) - Literal(intervals.reduce(_.add(_))) - } - - /** - * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are - * supported: - * - Single unit. - * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). - */ - override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { - import ctx._ - val s = value.getText - val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { - case (u, None) if u.endsWith("s") => - // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... - CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) - case (u, None) => - CalendarInterval.fromSingleUnitString(u, s) - case ("year", Some("month")) => - CalendarInterval.fromYearMonthString(s) - case ("day", Some("second")) => - CalendarInterval.fromDayTimeString(s) - case (from, Some(t)) => - throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) - } - assert(interval != null, "No interval can be constructed", ctx) - interval - } - - /* ******************************************************************************************** - * DataType parsing - * ******************************************************************************************** */ - /** - * Resolve/create a primitive type. 
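The concatenation of consecutive string literals mentioned above can be checked directly; a hedged sketch assuming the `CatalystSqlParser` entry point:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

// "'hello' 'world'" is two consecutive STRING tokens; createString concatenates them.
val e = CatalystSqlParser.parseExpression("'hello' 'world'")
println(e == Literal("helloworld"))   // expected: true
```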
- */ - override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { - (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { - case ("boolean", Nil) => BooleanType - case ("tinyint" | "byte", Nil) => ByteType - case ("smallint" | "short", Nil) => ShortType - case ("int" | "integer", Nil) => IntegerType - case ("bigint" | "long", Nil) => LongType - case ("float", Nil) => FloatType - case ("double", Nil) => DoubleType - case ("date", Nil) => DateType - case ("timestamp", Nil) => TimestampType - case ("char" | "varchar" | "string", Nil) => StringType - case ("char" | "varchar", _ :: Nil) => StringType - case ("binary", Nil) => BinaryType - case ("decimal", Nil) => DecimalType.USER_DEFAULT - case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) - case ("decimal", precision :: scale :: Nil) => - DecimalType(precision.getText.toInt, scale.getText.toInt) - case (dt, params) => - throw new ParseException( - s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) - } - } - - /** - * Create a complex DataType. Arrays, Maps and Structures are supported. - */ - override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { - ctx.complex.getType match { - case SqlBaseParser.ARRAY => - ArrayType(typedVisit(ctx.dataType(0))) - case SqlBaseParser.MAP => - MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) - case SqlBaseParser.STRUCT => - createStructType(ctx.colTypeList()) - } - } - - /** - * Create a [[StructType]] from a sequence of [[StructField]]s. - */ - protected def createStructType(ctx: ColTypeListContext): StructType = { - StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) - } - - /** - * Create a [[StructType]] from a number of column definitions. - */ - override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { - ctx.colType().asScala.map(visitColType) - } - - /** - * Create a [[StructField]] from a column definition. - */ - override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { - import ctx._ - - // Add the comment to the metadata. - val builder = new MetadataBuilder - if (STRING != null) { - builder.putString("comment", string(STRING)) - } - - StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build()) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala deleted file mode 100644 index c9a286374c..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.sql.catalyst.parser.ng
-
-import org.antlr.v4.runtime._
-import org.antlr.v4.runtime.atn.PredictionMode
-import org.antlr.v4.runtime.misc.ParseCancellationException
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.parser.ParserInterface
-import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.types.DataType
-
-/**
- * Base SQL parsing infrastructure.
- */
-abstract class AbstractSqlParser extends ParserInterface with Logging {
-
-  /** Creates/Resolves DataType for a given SQL string. */
-  def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
-    // TODO add this to the parser interface.
-    astBuilder.visitSingleDataType(parser.singleDataType())
-  }
-
-  /** Creates Expression for a given SQL string. */
-  override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
-    astBuilder.visitSingleExpression(parser.singleExpression())
-  }
-
-  /** Creates TableIdentifier for a given SQL string. */
-  override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
-    astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
-  }
-
-  /** Creates LogicalPlan for a given SQL string. */
-  override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
-    astBuilder.visitSingleStatement(parser.singleStatement()) match {
-      case plan: LogicalPlan => plan
-      case _ => nativeCommand(sqlText)
-    }
-  }
-
-  /** Get the builder (visitor) which converts a ParseTree into a AST. */
-  protected def astBuilder: AstBuilder
-
-  /** Create a native command, or fail when this is not supported. */
-  protected def nativeCommand(sqlText: String): LogicalPlan = {
-    val position = Origin(None, None)
-    throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position)
-  }
-
-  protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
-    logInfo(s"Parsing command: $command")
-
-    val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command))
-    lexer.removeErrorListeners()
-    lexer.addErrorListener(ParseErrorListener)
-
-    val tokenStream = new CommonTokenStream(lexer)
-    val parser = new SqlBaseParser(tokenStream)
-    parser.addParseListener(PostProcessor)
-    parser.removeErrorListeners()
-    parser.addErrorListener(ParseErrorListener)
-
-    try {
-      try {
-        // first, try parsing with potentially faster SLL mode
-        parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
-        toResult(parser)
-      }
-      catch {
-        case e: ParseCancellationException =>
-          // if we fail, parse with LL mode
-          tokenStream.reset() // rewind input stream
-          parser.reset()
-
-          // Try Again.
-          parser.getInterpreter.setPredictionMode(PredictionMode.LL)
-          toResult(parser)
-      }
-    }
-    catch {
-      case e: ParseException if e.command.isDefined =>
-        throw e
-      case e: ParseException =>
-        throw e.withCommand(command)
-      case e: AnalysisException =>
-        val position = Origin(e.line, e.startPosition)
-        throw new ParseException(Option(command), e.message, position, position)
-    }
-  }
-}
-
-/**
- * Concrete SQL parser for Catalyst-only SQL statements.
- */ -object CatalystSqlParser extends AbstractSqlParser { - val astBuilder = new AstBuilder -} - -/** - * This string stream provides the lexer with upper case characters only. This greatly simplifies - * lexing the stream, while we can maintain the original command. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream - * - * The comment below (taken from the original class) describes the rationale for doing this: - * - * This class provides and implementation for a case insensitive token checker for the lexical - * analysis part of antlr. By converting the token stream into upper case at the time when lexical - * rules are checked, this class ensures that the lexical rules need to just match the token with - * upper case letters as opposed to combination of upper case and lower case characters. This is - * purely used for matching lexical rules. The actual token text is stored in the same way as the - * user input without actually converting it into an upper case. The token values are generated by - * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead - * function and is purely used for matching lexical rules. This also means that the grammar will - * only accept capitalized tokens in case it is run from other tools like antlrworks which do not - * have the ANTLRNoCaseStringStream implementation. - */ - -private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { - override def LA(i: Int): Int = { - val la = super.LA(i) - if (la == 0 || la == IntStream.EOF) la - else Character.toUpperCase(la) - } -} - -/** - * The ParseErrorListener converts parse errors into AnalysisExceptions. - */ -case object ParseErrorListener extends BaseErrorListener { - override def syntaxError( - recognizer: Recognizer[_, _], - offendingSymbol: scala.Any, - line: Int, - charPositionInLine: Int, - msg: String, - e: RecognitionException): Unit = { - val position = Origin(Some(line), Some(charPositionInLine)) - throw new ParseException(None, msg, position, position) - } -} - -/** - * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It - * contains fields and an extended error message that make reporting and diagnosing errors easier. - */ -class ParseException( - val command: Option[String], - message: String, - val start: Origin, - val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { - - def this(message: String, ctx: ParserRuleContext) = { - this(Option(ParserUtils.command(ctx)), - message, - ParserUtils.position(ctx.getStart), - ParserUtils.position(ctx.getStop)) - } - - override def getMessage: String = { - val builder = new StringBuilder - builder ++= "\n" ++= message - start match { - case Origin(Some(l), Some(p)) => - builder ++= s"(line $l, pos $p)\n" - command.foreach { cmd => - val (above, below) = cmd.split("\n").splitAt(l) - builder ++= "\n== SQL ==\n" - above.foreach(builder ++= _ += '\n') - builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" - below.foreach(builder ++= _ += '\n') - } - case _ => - command.foreach { cmd => - builder ++= "\n== SQL ==\n" ++= cmd - } - } - builder.toString - } - - def withCommand(cmd: String): ParseException = { - new ParseException(Option(cmd), message, start, stop) - } -} - -/** - * The post-processor validates & cleans-up the parse tree during the parse process. 
- */ -case object PostProcessor extends SqlBaseBaseListener { - - /** Remove the back ticks from an Identifier. */ - override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { - replaceTokenByIdentifier(ctx, 1) { token => - // Remove the double back ticks in the string. - token.setText(token.getText.replace("``", "`")) - token - } - } - - /** Treat non-reserved keywords as Identifiers. */ - override def exitNonReserved(ctx: NonReservedContext): Unit = { - replaceTokenByIdentifier(ctx, 0)(identity) - } - - private def replaceTokenByIdentifier( - ctx: ParserRuleContext, - stripMargins: Int)( - f: CommonToken => CommonToken = identity): Unit = { - val parent = ctx.getParent - parent.removeLastChild() - val token = ctx.getChild(0).getPayload.asInstanceOf[Token] - parent.addChild(f(new CommonToken( - new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), - SqlBaseParser.IDENTIFIER, - token.getChannel, - token.getStartIndex + stripMargins, - token.getStopIndex - stripMargins))) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala deleted file mode 100644 index 1fbfa763b4..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} -import org.antlr.v4.runtime.misc.Interval -import org.antlr.v4.runtime.tree.TerminalNode - -import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} - -/** - * A collection of utility methods for use during the parsing process. - */ -object ParserUtils { - /** Get the command which created the token. */ - def command(ctx: ParserRuleContext): String = { - command(ctx.getStart.getInputStream) - } - - /** Get the command which created the token. */ - def command(stream: CharStream): String = { - stream.getText(Interval.of(0, stream.size())) - } - - /** Get the code that creates the given node. */ - def source(ctx: ParserRuleContext): String = { - val stream = ctx.getStart.getInputStream - stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) - } - - /** Get all the text which comes after the given rule. */ - def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) - - /** Get all the text which comes after the given token. 
*/ - def remainder(token: Token): String = { - val stream = token.getInputStream - val interval = Interval.of(token.getStopIndex + 1, stream.size()) - stream.getText(interval) - } - - /** Convert a string token into a string. */ - def string(token: Token): String = unescapeSQLString(token.getText) - - /** Convert a string node into a string. */ - def string(node: TerminalNode): String = unescapeSQLString(node.getText) - - /** Get the origin (line and position) of the token. */ - def position(token: Token): Origin = { - Origin(Option(token.getLine), Option(token.getCharPositionInLine)) - } - - /** Assert if a condition holds. If it doesn't throw a parse exception. */ - def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { - if (!f) { - throw new ParseException(message, ctx) - } - } - - /** - * Register the origin of the context. Any TreeNode created in the closure will be assigned the - * registered origin. This method restores the previously set origin after completion of the - * closure. - */ - def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { - val current = CurrentOrigin.get - CurrentOrigin.set(position(ctx.getStart)) - try { - f - } finally { - CurrentOrigin.set(current) - } - } - - /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ - implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { - /** - * Create a plan using the block of code when the given context exists. Otherwise return the - * original plan. - */ - def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { - if (ctx != null) { - f - } else { - plan - } - } - - /** - * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the - * passed function. The original plan is returned when the context does not exist. - */ - def optionalMap[C <: ParserRuleContext]( - ctx: C)( - f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { - if (ctx != null) { - f(ctx, plan) - } else { - plan - } - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala deleted file mode 100644 index 8b05f9e33d..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */
-package org.apache.spark.sql.catalyst.parser
-
-import org.apache.spark.SparkFunSuite
-
-class ASTNodeSuite extends SparkFunSuite {
-  test("SPARK-13157 - remainder must return all input chars") {
-    val inputs = Seq(
-      ("add jar", "file:///tmp/ab/TestUDTF.jar"),
-      ("add jar", "file:///tmp/a@b/TestUDTF.jar"),
-      ("add jar", "c:\\windows32\\TestUDTF.jar"),
-      ("add jar", "some \nbad\t\tfile\r\n.\njar"),
-      ("ADD JAR", "@*#&@(!#@$^*!@^@#(*!@#"),
-      ("SET", "foo=bar"),
-      ("SET", "foo*)(@#^*@&!#^=bar")
-    )
-    inputs.foreach {
-      case (command, arguments) =>
-        val node = ParseDriver.parsePlan(s"$command $arguments", null)
-        assert(node.remainder === arguments)
-    }
-  }
-}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
deleted file mode 100644
index 223485e292..0000000000
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala
+++ /dev/null
@@ -1,223 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */ - -package org.apache.spark.sql.catalyst.parser - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.unsafe.types.CalendarInterval - -class CatalystQlSuite extends PlanTest { - val parser = new CatalystQl() - import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.dsl.plans._ - - val star = UnresolvedAlias(UnresolvedStar(None)) - - test("test case insensitive") { - val result = OneRowRelation.select(1) - assert(result === parser.parsePlan("seLect 1")) - assert(result === parser.parsePlan("select 1")) - assert(result === parser.parsePlan("SELECT 1")) - } - - test("test NOT operator with comparison operations") { - val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE") - val expected = OneRowRelation.select(Not(GreaterThan(true, true))) - comparePlans(parsed, expected) - } - - test("test Union Distinct operator") { - val parsed1 = parser.parsePlan( - "SELECT * FROM t0 UNION SELECT * FROM t1") - val parsed2 = parser.parsePlan( - "SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") - val expected = Distinct(Union(table("t0").select(star), table("t1").select(star))) - .as("u_1").select(star) - comparePlans(parsed1, expected) - comparePlans(parsed2, expected) - } - - test("test Union All operator") { - val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1") - val expected = Union(table("t0").select(star), table("t1").select(star)).as("u_1").select(star) - comparePlans(parsed, expected) - } - - test("support hive interval literal") { - def checkInterval(sql: String, result: CalendarInterval): Unit = { - val parsed = parser.parsePlan(sql) - val expected = OneRowRelation.select(Literal(result)) - comparePlans(parsed, expected) - } - - def checkYearMonth(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' YEAR TO MONTH", - CalendarInterval.fromYearMonthString(lit)) - } - - def checkDayTime(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' DAY TO SECOND", - CalendarInterval.fromDayTimeString(lit)) - } - - def checkSingleUnit(lit: String, unit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' $unit", - CalendarInterval.fromSingleUnitString(unit, lit)) - } - - checkYearMonth("123-10") - checkYearMonth("496-0") - checkYearMonth("-2-3") - checkYearMonth("-123-0") - - checkDayTime("99 11:22:33.123456789") - checkDayTime("-99 11:22:33.123456789") - checkDayTime("10 9:8:7.123456789") - checkDayTime("1 0:0:0") - checkDayTime("-1 0:0:0") - checkDayTime("1 0:0:1") - - for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { - checkSingleUnit("7", unit) - checkSingleUnit("-7", unit) - checkSingleUnit("0", unit) - } - - checkSingleUnit("13.123456789", "second") - checkSingleUnit("-13.123456789", "second") - } - - test("support scientific notation") { - def assertRight(input: String, output: Double): Unit = { - val parsed = parser.parsePlan("SELECT " + input) - val expected = OneRowRelation.select(Literal(output)) - comparePlans(parsed, expected) - } - - assertRight("9.0e1", 90) - assertRight(".9e+2", 90) - assertRight("0.9e+2", 90) - assertRight("900e-1", 90) - assertRight("900.0E-1", 90) - assertRight("9.e+1", 90) - - 
intercept[AnalysisException](parser.parsePlan("SELECT .e3")) - } - - test("parse expressions") { - compareExpressions( - parser.parseExpression("prinln('hello', 'world')"), - UnresolvedFunction( - "prinln", Literal("hello") :: Literal("world") :: Nil, false)) - - compareExpressions( - parser.parseExpression("1 + r.r As q"), - Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")()) - - compareExpressions( - parser.parseExpression("1 - f('o', o(bar))"), - Subtract(Literal(1), - UnresolvedFunction("f", - Literal("o") :: - UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) :: - Nil, false))) - - intercept[AnalysisException](parser.parseExpression("1 - f('o', o(bar)) hello * world")) - } - - test("table identifier") { - assert(TableIdentifier("q") === parser.parseTableIdentifier("q")) - assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q")) - intercept[AnalysisException](parser.parseTableIdentifier("")) - intercept[AnalysisException](parser.parseTableIdentifier("d.q.g")) - } - - test("parse union/except/intersect") { - parser.parsePlan("select * from t1 union all select * from t2") - parser.parsePlan("select * from t1 union distinct select * from t2") - parser.parsePlan("select * from t1 union select * from t2") - parser.parsePlan("select * from t1 except select * from t2") - parser.parsePlan("select * from t1 intersect select * from t2") - parser.parsePlan("(select * from t1) union all (select * from t2)") - parser.parsePlan("(select * from t1) union distinct (select * from t2)") - parser.parsePlan("(select * from t1) union (select * from t2)") - parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t") - } - - test("window function: better support of parentheses") { - parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + - "order by 2) from windowData") - parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + - "order by 2) from windowData") - parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + - "order by 2) from windowData") - - parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + - "from windowData") - parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + - "from windowData") - parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + - "from windowData") - } - - test("very long AND/OR expression") { - val equals = (1 to 1000).map(x => s"$x == $x") - val expr = parser.parseExpression(equals.mkString(" AND ")) - assert(expr.isInstanceOf[And]) - assert(expr.collect( { case EqualTo(_, _) => true } ).size == 1000) - - val expr2 = parser.parseExpression(equals.mkString(" OR ")) - assert(expr2.isInstanceOf[Or]) - assert(expr2.collect( { case EqualTo(_, _) => true } ).size == 1000) - } - - test("subquery") { - parser.parsePlan("select (select max(b) from s) ss from t") - parser.parsePlan("select * from t where a = (select b from s)") - parser.parsePlan("select * from t group by g having a > (select b from s)") - } - - test("using clause in JOIN") { - // Tests parsing of using clause for different join types. 
- parser.parsePlan("select * from t1 join t2 using (c1)") - parser.parsePlan("select * from t1 join t2 using (c1, c2)") - parser.parsePlan("select * from t1 left join t2 using (c1, c2)") - parser.parsePlan("select * from t1 right join t2 using (c1, c2)") - parser.parsePlan("select * from t1 full outer join t2 using (c1, c2)") - parser.parsePlan("select * from t1 join t2 using (c1) join t3 using (c2)") - // Tests errors - // (1) Empty using clause - // (2) Qualified columns in using - // (3) Both on and using clause - var error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using ()")) - assert(error.message.contains("cannot recognize input near ')'")) - error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using (t1.c1)")) - assert(error.message.contains("mismatched input '.'")) - error = intercept[AnalysisException](parser.parsePlan("select * from t1" + - " join t2 using (c1) on t1.c1 = t2.c1")) - assert(error.message.contains("missing EOF at 'on' near ')'")) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index d9bd33c50a..07b89cb61f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException} import org.apache.spark.sql.types._ abstract class AbstractDataTypeParserSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala new file mode 100644 index 0000000000..db96bfb652 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +/** + * Test various parser errors. + */ +class ErrorParserSuite extends SparkFunSuite { + def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { + val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) + + // Check position. + assert(e.line.isDefined) + assert(e.line.get === line) + assert(e.startPosition.isDefined) + assert(e.startPosition.get === startPosition) + + // Check messages. 
+ val error = e.getMessage + messages.foreach { message => + assert(error.contains(message)) + } + } + + test("no viable input") { + intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") + intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") + intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") + } + + test("extraneous input") { + intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") + intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") + } + + test("mismatched input") { + intercept("select * from r order by q from t", 1, 27, + "mismatched input", + "---------------------------^^^") + intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") + } + + test("semantic errors") { + intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", + "^^^") + intercept("select * from r where a in (select * from t)", 1, 24, + "IN with a Sub-query is currently not supported", + "------------------------^^^") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala new file mode 100644 index 0000000000..a80d29ce5d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Test basic expression parsing. If a type of expression is supported it should be tested here. + * + * Please note that some of the expressions test don't have to be sound expressions, only their + * structure needs to be valid. Unsound expressions should be caught by the Analyzer or + * CheckAnalysis classes. 
+ */ +class ExpressionParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, e: Expression): Unit = { + compareExpressions(parseExpression(sqlCommand), e) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parseExpression(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("star expressions") { + // Global Star + assertEqual("*", UnresolvedStar(None)) + + // Targeted Star + assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) + } + + // NamedExpression (Alias/Multialias) + test("named expressions") { + // No Alias + val r0 = 'a + assertEqual("a", r0) + + // Single Alias. + val r1 = 'a as "b" + assertEqual("a as b", r1) + assertEqual("a b", r1) + + // Multi-Alias + assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + + // Numeric literals without a space between the literal qualifier and the alias, should not be + // interpreted as such. An unresolved reference should be returned instead. + // TODO add the JIRA-ticket number. + assertEqual("1SL", Symbol("1SL")) + + // Aliased star is allowed. + assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + } + + test("binary logical expressions") { + // And + assertEqual("a and b", 'a && 'b) + + // Or + assertEqual("a or b", 'a || 'b) + + // Combination And/Or check precedence + assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) + assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + + // Multiple AND/OR get converted into a balanced tree + assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) + assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + } + + test("long binary logical expressions") { + def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { + val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) + val e = parseExpression(sql) + assert(e.collect { case _: EqualTo => true }.size === 1000) + assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) + } + testVeryBinaryExpression(" AND ", classOf[And]) + testVeryBinaryExpression(" OR ", classOf[Or]) + } + + test("not expressions") { + assertEqual("not a", !'a) + assertEqual("!a", !'a) + assertEqual("not true > true", Not(GreaterThan(true, true))) + } + + test("exists expression") { + intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + } + + test("comparison expressions") { + assertEqual("a = b", 'a === 'b) + assertEqual("a == b", 'a === 'b) + assertEqual("a <=> b", 'a <=> 'b) + assertEqual("a <> b", 'a =!= 'b) + assertEqual("a != b", 'a =!= 'b) + assertEqual("a < b", 'a < 'b) + assertEqual("a <= b", 'a <= 'b) + assertEqual("a > b", 'a > 'b) + assertEqual("a >= b", 'a >= 'b) + } + + test("between expressions") { + assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) + assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + } + + test("in expressions") { + assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) + assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + } + + test("in sub-query") { + intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + } + + test("like expressions") { + assertEqual("a like 'pattern%'", 'a like "pattern%") + assertEqual("a not 
like 'pattern%'", !('a like "pattern%")) + assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) + assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + } + + test("is null expressions") { + assertEqual("a is null", 'a.isNull) + assertEqual("a is not null", 'a.isNotNull) + assertEqual("a = b is null", ('a === 'b).isNull) + assertEqual("a = b is not null", ('a === 'b).isNotNull) + } + + test("binary arithmetic expressions") { + // Simple operations + assertEqual("a * b", 'a * 'b) + assertEqual("a / b", 'a / 'b) + assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a % b", 'a % 'b) + assertEqual("a + b", 'a + 'b) + assertEqual("a - b", 'a - 'b) + assertEqual("a & b", 'a & 'b) + assertEqual("a ^ b", 'a ^ 'b) + assertEqual("a | b", 'a | 'b) + + // Check precedences + assertEqual( + "a * t | b ^ c & d - e + f % g DIV h / i * k", + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + } + + test("unary arithmetic expressions") { + assertEqual("+a", 'a) + assertEqual("-a", -'a) + assertEqual("~a", ~'a) + assertEqual("-+~~a", -(~(~'a))) + } + + test("cast expressions") { + // Note that DataType parsing is tested elsewhere. + assertEqual("cast(a as int)", 'a.cast(IntegerType)) + assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) + assertEqual("cast(a as array)", 'a.cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + } + + test("function expressions") { + assertEqual("foo()", 'foo.function()) + assertEqual("foo.bar()", Symbol("foo.bar").function()) + assertEqual("foo(*)", 'foo.function(star())) + assertEqual("count(*)", 'count.function(1)) + assertEqual("foo(a, b)", 'foo.function('a, 'b)) + assertEqual("foo(all a, b)", 'foo.function('a, 'b)) + assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) + assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) + assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + } + + test("window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } + + // Basic window testing. + assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) + assertEqual("foo(*) over ()", windowed()) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + + // Test use of expressions in window functions. 
+ assertEqual( + "sum(product + 1) over (partition by ((product) + (1)) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + assertEqual( + "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + + // Range/Row + val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) + val boundaries = Seq( + ("10 preceding", ValuePreceding(10), CurrentRow), + ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing), + ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), + ("between current row and 5 following", CurrentRow, ValueFollowing(5)), + ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ) + frameTypes.foreach { + case (frameTypeSql, frameType) => + boundaries.foreach { + case (boundarySql, begin, end) => + val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" + val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + assertEqual(query, expr) + } + } + + // We cannot use non integer constants. + intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", + "Frame bound value must be a constant integer.") + + // We cannot use an arbitrary expression. + intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", + "Frame bound value must be a constant integer.") + } + + test("row constructor") { + // Note that '(a)' will be interpreted as a nested expression. + assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) + assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + } + + test("scalar sub-query") { + assertEqual( + "(select max(val) from tbl) > current", + ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + assertEqual( + "a = (select b from s)", + 'a === ScalarSubquery(table("s").select('b))) + } + + test("case when") { + assertEqual("case a when 1 then b when 2 then c else d end", + CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case when a = 1 then b when a = 2 then c else d end", + CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + } + + test("dereference") { + assertEqual("a.b", UnresolvedAttribute("a.b")) + assertEqual("`select`.b", UnresolvedAttribute("select.b")) + assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + } + + test("reference") { + // Regular + assertEqual("a", 'a) + + // Starting with a digit. + assertEqual("1a", Symbol("1a")) + + // Quoted using a keyword. + assertEqual("`select`", 'select) + + // Unquoted using an unreserved keyword. + assertEqual("columns", 'columns) + } + + test("subscript") { + assertEqual("a[b]", 'a.getItem('b)) + assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + } + + test("parenthesis") { + assertEqual("(a)", 'a) + assertEqual("r * (a + b)", 'r * ('a + 'b)) + } + + test("type constructors") { + // Dates. 
+ assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) + intercept[IllegalArgumentException] { + parseExpression("DAtE 'mar 11 2016'") + } + + // Timestamps. + assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", + Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) + intercept[IllegalArgumentException] { + parseExpression("timestamP '2016-33-11 20:54:00.000'") + } + + // Unsupported datatype. + intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") + } + + test("literals") { + // NULL + assertEqual("null", Literal(null)) + + // Boolean + assertEqual("trUe", Literal(true)) + assertEqual("False", Literal(false)) + + // Integral should have the narrowest possible type + assertEqual("787324", Literal(787324)) + assertEqual("7873247234798249234", Literal(7873247234798249234L)) + assertEqual("78732472347982492793712334", + Literal(BigDecimal("78732472347982492793712334").underlying())) + + // Decimal + assertEqual("7873247234798249279371.2334", + Literal(BigDecimal("7873247234798249279371.2334").underlying())) + + // Scientific Decimal + assertEqual("9.0e1", 90d) + assertEqual(".9e+2", 90d) + assertEqual("0.9e+2", 90d) + assertEqual("900e-1", 90d) + assertEqual("900.0E-1", 90d) + assertEqual("9.e+1", 90d) + intercept(".e3") + + // Tiny Int Literal + assertEqual("10Y", Literal(10.toByte)) + intercept("-1000Y") + + // Small Int Literal + assertEqual("10S", Literal(10.toShort)) + intercept("40000S") + + // Long Int Literal + assertEqual("10L", Literal(10L)) + intercept("78732472347982492793712334L") + + // Double Literal + assertEqual("10.0D", Literal(10.0D)) + // TODO we need to figure out if we should throw an exception here! + assertEqual("1E309", Literal(Double.PositiveInfinity)) + } + + test("strings") { + // Single Strings. + assertEqual("\"hello\"", "hello") + assertEqual("'hello'", "hello") + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld") + assertEqual("'hello' \" \" 'world'", "hello world") + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%") + assertEqual("'no-pattern\\%'", "no-pattern\\%") + assertEqual("'pattern\\\\%'", "pattern\\%") + assertEqual("'pattern\\\\\\%'", "pattern\\\\%") + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') + assertEqual("'\\''", "\'") // Single quote + assertEqual("'\\\"'", "\"") // Double quote + assertEqual("'\\b'", "\b") // Backspace + assertEqual("'\\n'", "\n") // Newline + assertEqual("'\\r'", "\r") // Carriage return + assertEqual("'\\t'", "\t") // Tab character + assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") + + // Unicode + assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + } + + test("intervals") { + def intervalLiteral(u: String, s: String): Literal = { + Literal(CalendarInterval.fromSingleUnitString(u, s)) + } + + // Empty interval statement + intercept("interval", "at least one time unit should be given for interval literal") + + // Single Intervals. 
+ val units = Seq( + "year", + "month", + "week", + "day", + "hour", + "minute", + "second", + "millisecond", + "microsecond") + val forms = Seq("", "s") + val values = Seq("0", "10", "-7", "21") + units.foreach { unit => + forms.foreach { form => + values.foreach { value => + val expected = intervalLiteral(unit, value) + assertEqual(s"interval $value $unit$form", expected) + assertEqual(s"interval '$value' $unit$form", expected) + } + } + } + + // Hive nanosecond notation. + assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) + assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) + + // Non Existing unit + intercept("interval 10 nanoseconds", "No interval can be constructed") + + // Year-Month intervals. + val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") + yearMonthValues.foreach { value => + val result = Literal(CalendarInterval.fromYearMonthString(value)) + assertEqual(s"interval '$value' year to month", result) + } + + // Day-Time intervals. + val datTimeValues = Seq( + "99 11:22:33.123456789", + "-99 11:22:33.123456789", + "10 9:8:7.123456789", + "1 0:0:0", + "-1 0:0:0", + "1 0:0:1") + datTimeValues.foreach { value => + val result = Literal(CalendarInterval.fromDayTimeString(value)) + assertEqual(s"interval '$value' day to second", result) + } + + // Unknown FROM TO intervals + intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") + + // Composed intervals. + assertEqual( + "interval 3 months 22 seconds 1 millisecond", + Literal(new CalendarInterval(3, 22001000L))) + assertEqual( + "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", + Literal(new CalendarInterval(14, + 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) + } + + test("composed expressions") { + assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) + assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala new file mode 100644 index 0000000000..23f05ce846 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PlanParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parsePlan(sqlCommand), plan) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("case insensitive") { + val plan = table("a").select(star()) + assertEqual("sELEct * FroM a", plan) + assertEqual("select * fRoM a", plan) + assertEqual("SELECT * FROM a", plan) + } + + test("show functions") { + assertEqual("show functions", ShowFunctions(None, None)) + assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) + assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) + assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) + intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") + } + + test("describe function") { + assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) + assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) + assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) + assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + } + + test("set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + + assertEqual("select * from a union select * from b", Distinct(a.union(b))) + assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) + assertEqual("select * from a union all select * from b", a.union(b)) + assertEqual("select * from a except select * from b", a.except(b)) + intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") + assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a intersect select * from b", a.intersect(b)) + intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") + assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + } + + test("common table expressions") { + def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + val ctes = namedPlans.map { + case (name, cte) => + name -> SubqueryAlias(name, cte) + }.toMap + With(plan, ctes) + } + assertEqual( + "with cte1 as (select * from a) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + assertEqual( + "with cte1 (select 1) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) + assertEqual( + "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", + cte(table("cte2").select(star()), + "cte1" -> OneRowRelation.select(1), + "cte2" -> table("cte1").select(star()))) + intercept( + "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", + "Name 'cte1' is used for multiple common table expressions") + } + + test("simple select query") { + 
assertEqual("select 1", OneRowRelation.select(1)) + assertEqual("select a, b", OneRowRelation.select('a, 'b)) + assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual( + "select a, b from db.c having x < 1", + table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) + assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + } + + test("reverse select query") { + assertEqual("from a", table("a")) + assertEqual("from a select b, c", table("a").select('b, 'c)) + assertEqual( + "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) + assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual( + "from (from a union all from b) c select *", + table("a").union(table("b")).as("c").select(star())) + } + + test("transform query spec") { + val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) + assertEqual("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + assertEqual("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("multi select query") { + assertEqual( + "from a select * select * where s < 10", + table("a").select(star()).union(table("a").where('s < 10).select(star()))) + intercept( + "from a select * select * from x where a.s < 10", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + assertEqual( + "from a insert into tbl1 select * insert into tbl2 select * where s < 10", + table("a").select(star()).insertInto("tbl1").union( + table("a").where('s < 10).select(star()).insertInto("tbl2"))) + } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = table("t").select(star()) + + val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) + val limitWindowClauses = Seq( + ("", (p: LogicalPlan) => p), + (" limit 10", (p: LogicalPlan) => p.limit(10)), + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), + (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + ) + + val orderSortDistrClusterClauses = Seq( + ("", basePlan), + (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), + (" distribute by a, b", basePlan.distribute('a, 'b)), + (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), + (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) + ) + + orderSortDistrClusterClauses.foreach { + case (s1, p1) => + limitWindowClauses.foreach { + case (s2, pf2) => + assertEqual(baseSql + s1 + s2, pf2(p1)) + } + } + + val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" + intercept(s"$baseSql order by a sort by a", msg) + intercept(s"$baseSql cluster by a distribute by a", msg) + intercept(s"$baseSql order by a cluster by a", msg) + intercept(s"$baseSql order by a distribute by a", msg) + } + + test("insert into") { + val sql = "select * from t" 
+ val plan = table("t").select(star()) + def insert( + partition: Map[String, Option[String]], + overwrite: Boolean = false, + ifNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + + // Single inserts + assertEqual(s"insert overwrite table s $sql", + insert(Map.empty, overwrite = true)) + assertEqual(s"insert overwrite table s if not exists $sql", + insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert into s $sql", + insert(Map.empty)) + assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", + insert(Map("c" -> Option("d"), "e" -> Option("1")))) + assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", + insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) + + // Multi insert + val plan2 = table("t").where('x > 5).select(star()) + assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", + InsertIntoTable( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + InsertIntoTable( + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + } + + test("aggregation") { + val sql = "select a, b, sum(c) as c from d group by a, b" + + // Normal + assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + + // Cube + assertEqual(s"$sql with cube", + table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Rollup + assertEqual(s"$sql with rollup", + table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Grouping Sets + assertEqual(s"$sql grouping sets((a, b), (a), ())", + GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + intercept(s"$sql grouping sets((a, b), (c), ())", + "c doesn't show up in the GROUP BY list") + } + + test("limit") { + val sql = "select * from t" + val plan = table("t").select(star()) + assertEqual(s"$sql limit 10", plan.limit(10)) + assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) + } + + test("window spec") { + // Note that WindowSpecs are testing in the ExpressionParserSuite + val sql = "select * from t" + val plan = table("t").select(star()) + val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + + // Test window resolution. + val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) + assertEqual( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w1""".stripMargin, + WithWindowDefinition(ws1, plan)) + + // Fail with no reference. + intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") + + // Fail when resolved reference is not a window spec. 
+ intercept( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w2""".stripMargin, + "Window reference 'w2' is not a window specification" + ) + } + + test("lateral view") { + // Single lateral view + assertEqual( + "select * from t lateral view explode(x) expl as x", + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .select(star())) + + // Multiple lateral views + assertEqual( + """select * + |from t + |lateral view explode(x) expl + |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .select(star())) + + // Multi-Insert lateral views. + val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + assertEqual( + """from t1 + |lateral view explode(x) expl as x + |insert into t2 + |select * + |lateral view json_tuple(x, y) jtup q, z + |insert into t3 + |select * + |where s < 10 + """.stripMargin, + Union(from + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .select(star()) + .insertInto("t2"), + from.where('s < 10).select(star()).insertInto("t3"))) + + // Unsupported generator. + intercept( + "select * from t lateral view posexplode(x) posexpl as x, y", + "Generator function 'posexplode' is not supported") + } + + test("joins") { + // Test single joins. + val testUnconditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t as tt $sql u", + table("t").as("tt").join(table("u"), jt, None).select(star())) + } + val testConditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u as uu on a = b", + table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + } + val testNaturalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t tt natural $sql u as uu", + table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) + } + val testUsingJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u using(a, b)", + table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) + } + val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) + + def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { + tests.foreach(_(sql, jt)) + } + test("cross join", Inner, Seq(testUnconditionalJoin)) + test(",", Inner, Seq(testUnconditionalJoin)) + test("join", Inner, testAll) + test("inner join", Inner, testAll) + test("left join", LeftOuter, testAll) + test("left outer join", LeftOuter, testAll) + test("right join", RightOuter, testAll) + test("right outer join", RightOuter, testAll) + test("full join", FullOuter, testAll) + test("full outer join", FullOuter, testAll) + + // Test multiple consecutive joins + assertEqual( + "select * from a join b join c right join d", + table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + } + + test("sampled relations") { + val sql = "select * from t" + assertEqual(s"$sql tablesample(100 rows)", + table("t").limit(100).select(star())) + assertEqual(s"$sql tablesample(43 percent) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) 
as x", + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", + "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + intercept(s"$sql tablesample(bucket 11 out of 10) as x", + s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + } + + test("sub-query") { + val plan = table("t0").select('id) + assertEqual("select id from (t0)", plan) + assertEqual("select id from ((((((t0))))))", plan) + assertEqual( + "(select * from t1) union distinct (select * from t2)", + Distinct(table("t1").select(star()).union(table("t2").select(star())))) + assertEqual( + "select * from ((select * from t1) union (select * from t2)) t", + Distinct( + table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) + assertEqual( + """select id + |from (((select id from t0) + | union all + | (select id from t0)) + | union all + | (select id from t0)) as u_1 + """.stripMargin, + plan.union(plan).union(plan).as("u_1").select('id)) + } + + test("scalar sub-query") { + assertEqual( + "select (select max(b) from s) ss from t", + table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + assertEqual( + "select * from t where a = (select b from s)", + table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + assertEqual( + "select g from t group by g having a > (select b from s)", + table("t") + .groupBy('g)('g) + .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + } + + test("table reference") { + assertEqual("table t", table("t")) + assertEqual("table d.t", table("d", "t")) + } + + test("inline table") { + assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( + Seq('col1.int), + Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual( + "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", + LocalRelation.fromExternalRows( + Seq('a.int, 'b.string), + Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) + intercept("values (a, 'a'), (b, 'b')", + "All expressions in an inline table must be constants.") + intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", + "Number of aliases must match the number of fields in an inline table.") + intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala new file mode 100644 index 0000000000..297b1931a9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier + +class TableIdentifierParserSuite extends SparkFunSuite { + import CatalystSqlParser._ + + test("table identifier") { + // Regular names. + assert(TableIdentifier("q") === parseTableIdentifier("q")) + assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) + + // Illegal names. + intercept[ParseException](parseTableIdentifier("")) + intercept[ParseException](parseTableIdentifier("d.q.g")) + + // SQL Keywords. + val keywords = Seq("select", "from", "where", "left", "right") + keywords.foreach { keyword => + intercept[ParseException](parseTableIdentifier(keyword)) + assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) + assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala deleted file mode 100644 index 1963fc368f..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.SparkFunSuite - -/** - * Test various parser errors. - */ -class ErrorParserSuite extends SparkFunSuite { - def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { - val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) - - // Check position. - assert(e.line.isDefined) - assert(e.line.get === line) - assert(e.startPosition.isDefined) - assert(e.startPosition.get === startPosition) - - // Check messages. 
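The intercept helper above verifies the reported line and startPosition, and the expected strings in the tests that follow (for example "-------^^^") encode that position directly. A minimal standalone sketch of that convention, inferred from the expected strings rather than taken from the parser's own error formatting:

    // One dash per character before the error position, then a caret marker.
    // caretMarker(7) == "-------^^^"; caretMarker(0) == "^^^"
    def caretMarker(startPosition: Int): String = "-" * startPosition + "^^^"
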
- val error = e.getMessage - messages.foreach { message => - assert(error.contains(message)) - } - } - - test("no viable input") { - intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") - intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") - intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") - } - - test("extraneous input") { - intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") - intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") - } - - test("mismatched input") { - intercept("select * from r order by q from t", 1, 27, - "mismatched input", - "---------------------------^^^") - intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") - } - - test("semantic errors") { - intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, - "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", - "^^^") - intercept("select * from r where a in (select * from t)", 1, 24, - "IN with a Sub-query is currently not supported", - "------------------------^^^") - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala deleted file mode 100644 index 32311a5a66..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala +++ /dev/null @@ -1,497 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval - -/** - * Test basic expression parsing. If a type of expression is supported it should be tested here. - * - * Please note that some of the expressions test don't have to be sound expressions, only their - * structure needs to be valid. Unsound expressions should be caught by the Analyzer or - * CheckAnalysis classes. 
- */ -class ExpressionParserSuite extends PlanTest { - import CatalystSqlParser._ - import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.dsl.plans._ - - def assertEqual(sqlCommand: String, e: Expression): Unit = { - compareExpressions(parseExpression(sqlCommand), e) - } - - def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parseExpression(sqlCommand)) - messages.foreach { message => - assert(e.message.contains(message)) - } - } - - test("star expressions") { - // Global Star - assertEqual("*", UnresolvedStar(None)) - - // Targeted Star - assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) - } - - // NamedExpression (Alias/Multialias) - test("named expressions") { - // No Alias - val r0 = 'a - assertEqual("a", r0) - - // Single Alias. - val r1 = 'a as "b" - assertEqual("a as b", r1) - assertEqual("a b", r1) - - // Multi-Alias - assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) - assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) - - // Numeric literals without a space between the literal qualifier and the alias, should not be - // interpreted as such. An unresolved reference should be returned instead. - // TODO add the JIRA-ticket number. - assertEqual("1SL", Symbol("1SL")) - - // Aliased star is allowed. - assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) - } - - test("binary logical expressions") { - // And - assertEqual("a and b", 'a && 'b) - - // Or - assertEqual("a or b", 'a || 'b) - - // Combination And/Or check precedence - assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) - assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) - - // Multiple AND/OR get converted into a balanced tree - assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) - assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) - } - - test("long binary logical expressions") { - def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { - val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) - val e = parseExpression(sql) - assert(e.collect { case _: EqualTo => true }.size === 1000) - assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) - } - testVeryBinaryExpression(" AND ", classOf[And]) - testVeryBinaryExpression(" OR ", classOf[Or]) - } - - test("not expressions") { - assertEqual("not a", !'a) - assertEqual("!a", !'a) - assertEqual("not true > true", Not(GreaterThan(true, true))) - } - - test("exists expression") { - intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") - } - - test("comparison expressions") { - assertEqual("a = b", 'a === 'b) - assertEqual("a == b", 'a === 'b) - assertEqual("a <=> b", 'a <=> 'b) - assertEqual("a <> b", 'a =!= 'b) - assertEqual("a != b", 'a =!= 'b) - assertEqual("a < b", 'a < 'b) - assertEqual("a <= b", 'a <= 'b) - assertEqual("a > b", 'a > 'b) - assertEqual("a >= b", 'a >= 'b) - } - - test("between expressions") { - assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) - assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) - } - - test("in expressions") { - assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) - assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) - } - - test("in sub-query") { - intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") - } - - test("like expressions") { - assertEqual("a like 'pattern%'", 'a like "pattern%") - assertEqual("a not 
like 'pattern%'", !('a like "pattern%")) - assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") - assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) - assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") - assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) - } - - test("is null expressions") { - assertEqual("a is null", 'a.isNull) - assertEqual("a is not null", 'a.isNotNull) - assertEqual("a = b is null", ('a === 'b).isNull) - assertEqual("a = b is not null", ('a === 'b).isNotNull) - } - - test("binary arithmetic expressions") { - // Simple operations - assertEqual("a * b", 'a * 'b) - assertEqual("a / b", 'a / 'b) - assertEqual("a DIV b", ('a / 'b).cast(LongType)) - assertEqual("a % b", 'a % 'b) - assertEqual("a + b", 'a + 'b) - assertEqual("a - b", 'a - 'b) - assertEqual("a & b", 'a & 'b) - assertEqual("a ^ b", 'a ^ 'b) - assertEqual("a | b", 'a | 'b) - - // Check precedences - assertEqual( - "a * t | b ^ c & d - e + f % g DIV h / i * k", - 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) - } - - test("unary arithmetic expressions") { - assertEqual("+a", 'a) - assertEqual("-a", -'a) - assertEqual("~a", ~'a) - assertEqual("-+~~a", -(~(~'a))) - } - - test("cast expressions") { - // Note that DataType parsing is tested elsewhere. - assertEqual("cast(a as int)", 'a.cast(IntegerType)) - assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) - assertEqual("cast(a as array)", 'a.cast(ArrayType(IntegerType))) - assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) - } - - test("function expressions") { - assertEqual("foo()", 'foo.function()) - assertEqual("foo.bar()", Symbol("foo.bar").function()) - assertEqual("foo(*)", 'foo.function(star())) - assertEqual("count(*)", 'count.function(1)) - assertEqual("foo(a, b)", 'foo.function('a, 'b)) - assertEqual("foo(all a, b)", 'foo.function('a, 'b)) - assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) - assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) - assertEqual("`select`(all a, b)", 'select.function('a, 'b)) - } - - test("window function expressions") { - val func = 'foo.function(star()) - def windowed( - partitioning: Seq[Expression] = Seq.empty, - ordering: Seq[SortOrder] = Seq.empty, - frame: WindowFrame = UnspecifiedFrame): Expression = { - WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) - } - - // Basic window testing. - assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) - assertEqual("foo(*) over ()", windowed()) - assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) - assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) - - // Test use of expressions in window functions. 
- assertEqual( - "sum(product + 1) over (partition by ((product) + (1)) order by 2)", - WindowExpression('sum.function('product + 1), - WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) - assertEqual( - "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", - WindowExpression('sum.function('product + 1), - WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) - - // Range/Row - val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) - val boundaries = Seq( - ("10 preceding", ValuePreceding(10), CurrentRow), - ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis - ("unbounded preceding", UnboundedPreceding, CurrentRow), - ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis - ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), - ("between unbounded preceding and unbounded following", - UnboundedPreceding, UnboundedFollowing), - ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), - ("between current row and 5 following", CurrentRow, ValueFollowing(5)), - ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) - ) - frameTypes.foreach { - case (frameTypeSql, frameType) => - boundaries.foreach { - case (boundarySql, begin, end) => - val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" - val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) - assertEqual(query, expr) - } - } - - // We cannot use non integer constants. - intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", - "Frame bound value must be a constant integer.") - - // We cannot use an arbitrary expression. - intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", - "Frame bound value must be a constant integer.") - } - - test("row constructor") { - // Note that '(a)' will be interpreted as a nested expression. - assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) - assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) - } - - test("scalar sub-query") { - assertEqual( - "(select max(val) from tbl) > current", - ScalarSubquery(table("tbl").select('max.function('val))) > 'current) - assertEqual( - "a = (select b from s)", - 'a === ScalarSubquery(table("s").select('b))) - } - - test("case when") { - assertEqual("case a when 1 then b when 2 then c else d end", - CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) - assertEqual("case when a = 1 then b when a = 2 then c else d end", - CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) - } - - test("dereference") { - assertEqual("a.b", UnresolvedAttribute("a.b")) - assertEqual("`select`.b", UnresolvedAttribute("select.b")) - assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. - assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) - } - - test("reference") { - // Regular - assertEqual("a", 'a) - - // Starting with a digit. - assertEqual("1a", Symbol("1a")) - - // Quoted using a keyword. - assertEqual("`select`", 'select) - - // Unquoted using an unreserved keyword. - assertEqual("columns", 'columns) - } - - test("subscript") { - assertEqual("a[b]", 'a.getItem('b)) - assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) - assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) - } - - test("parenthesis") { - assertEqual("(a)", 'a) - assertEqual("r * (a + b)", 'r * ('a + 'b)) - } - - test("type constructors") { - // Dates. 
- assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) - intercept[IllegalArgumentException] { - parseExpression("DAtE 'mar 11 2016'") - } - - // Timestamps. - assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", - Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) - intercept[IllegalArgumentException] { - parseExpression("timestamP '2016-33-11 20:54:00.000'") - } - - // Unsupported datatype. - intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") - } - - test("literals") { - // NULL - assertEqual("null", Literal(null)) - - // Boolean - assertEqual("trUe", Literal(true)) - assertEqual("False", Literal(false)) - - // Integral should have the narrowest possible type - assertEqual("787324", Literal(787324)) - assertEqual("7873247234798249234", Literal(7873247234798249234L)) - assertEqual("78732472347982492793712334", - Literal(BigDecimal("78732472347982492793712334").underlying())) - - // Decimal - assertEqual("7873247234798249279371.2334", - Literal(BigDecimal("7873247234798249279371.2334").underlying())) - - // Scientific Decimal - assertEqual("9.0e1", 90d) - assertEqual(".9e+2", 90d) - assertEqual("0.9e+2", 90d) - assertEqual("900e-1", 90d) - assertEqual("900.0E-1", 90d) - assertEqual("9.e+1", 90d) - intercept(".e3") - - // Tiny Int Literal - assertEqual("10Y", Literal(10.toByte)) - intercept("-1000Y") - - // Small Int Literal - assertEqual("10S", Literal(10.toShort)) - intercept("40000S") - - // Long Int Literal - assertEqual("10L", Literal(10L)) - intercept("78732472347982492793712334L") - - // Double Literal - assertEqual("10.0D", Literal(10.0D)) - // TODO we need to figure out if we should throw an exception here! - assertEqual("1E309", Literal(Double.PositiveInfinity)) - } - - test("strings") { - // Single Strings. - assertEqual("\"hello\"", "hello") - assertEqual("'hello'", "hello") - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld") - assertEqual("'hello' \" \" 'world'", "hello world") - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. - // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%") - assertEqual("'no-pattern\\%'", "no-pattern\\%") - assertEqual("'pattern\\\\%'", "pattern\\%") - assertEqual("'pattern\\\\\\%'", "pattern\\\\%") - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') - assertEqual("'\\''", "\'") // Single quote - assertEqual("'\\\"'", "\"") // Double quote - assertEqual("'\\b'", "\b") // Backspace - assertEqual("'\\n'", "\n") // Newline - assertEqual("'\\r'", "\r") // Carriage return - assertEqual("'\\t'", "\t") // Tab character - assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") - - // Unicode - assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") - } - - test("intervals") { - def intervalLiteral(u: String, s: String): Literal = { - Literal(CalendarInterval.fromSingleUnitString(u, s)) - } - - // Empty interval statement - intercept("interval", "at least one time unit should be given for interval literal") - - // Single Intervals. 
- val units = Seq( - "year", - "month", - "week", - "day", - "hour", - "minute", - "second", - "millisecond", - "microsecond") - val forms = Seq("", "s") - val values = Seq("0", "10", "-7", "21") - units.foreach { unit => - forms.foreach { form => - values.foreach { value => - val expected = intervalLiteral(unit, value) - assertEqual(s"interval $value $unit$form", expected) - assertEqual(s"interval '$value' $unit$form", expected) - } - } - } - - // Hive nanosecond notation. - assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) - assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) - - // Non Existing unit - intercept("interval 10 nanoseconds", "No interval can be constructed") - - // Year-Month intervals. - val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") - yearMonthValues.foreach { value => - val result = Literal(CalendarInterval.fromYearMonthString(value)) - assertEqual(s"interval '$value' year to month", result) - } - - // Day-Time intervals. - val datTimeValues = Seq( - "99 11:22:33.123456789", - "-99 11:22:33.123456789", - "10 9:8:7.123456789", - "1 0:0:0", - "-1 0:0:0", - "1 0:0:1") - datTimeValues.foreach { value => - val result = Literal(CalendarInterval.fromDayTimeString(value)) - assertEqual(s"interval '$value' day to second", result) - } - - // Unknown FROM TO intervals - intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") - - // Composed intervals. - assertEqual( - "interval 3 months 22 seconds 1 millisecond", - Literal(new CalendarInterval(3, 22001000L))) - assertEqual( - "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", - Literal(new CalendarInterval(14, - 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) - } - - test("composed expressions") { - assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) - assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) - intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala deleted file mode 100644 index 4206d22ca7..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala +++ /dev/null @@ -1,429 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.{BooleanType, IntegerType} - -class PlanParserSuite extends PlanTest { - import CatalystSqlParser._ - import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.dsl.plans._ - - def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan) - } - - def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parsePlan(sqlCommand)) - messages.foreach { message => - assert(e.message.contains(message)) - } - } - - test("case insensitive") { - val plan = table("a").select(star()) - assertEqual("sELEct * FroM a", plan) - assertEqual("select * fRoM a", plan) - assertEqual("SELECT * FROM a", plan) - } - - test("show functions") { - assertEqual("show functions", ShowFunctions(None, None)) - assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) - assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) - assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) - intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") - } - - test("describe function") { - assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) - assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) - assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) - assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) - } - - test("set operations") { - val a = table("a").select(star()) - val b = table("b").select(star()) - - assertEqual("select * from a union select * from b", Distinct(a.union(b))) - assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) - assertEqual("select * from a union all select * from b", a.union(b)) - assertEqual("select * from a except select * from b", a.except(b)) - intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") - assertEqual("select * from a except distinct select * from b", a.except(b)) - assertEqual("select * from a intersect select * from b", a.intersect(b)) - intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") - assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) - } - - test("common table expressions") { - def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { - val ctes = namedPlans.map { - case (name, cte) => - name -> SubqueryAlias(name, cte) - }.toMap - With(plan, ctes) - } - assertEqual( - "with cte1 as (select * from a) select * from cte1", - cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) - assertEqual( - "with cte1 (select 1) select * from cte1", - cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) - assertEqual( - "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", - cte(table("cte2").select(star()), - "cte1" -> OneRowRelation.select(1), - "cte2" -> table("cte1").select(star()))) - intercept( - "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", - "Name 'cte1' is used for multiple common table expressions") - } - - test("simple select query") { - 
assertEqual("select 1", OneRowRelation.select(1)) - assertEqual("select a, b", OneRowRelation.select('a, 'b)) - assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) - assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) - assertEqual( - "select a, b from db.c having x < 1", - table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) - assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) - assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) - } - - test("reverse select query") { - assertEqual("from a", table("a")) - assertEqual("from a select b, c", table("a").select('b, 'c)) - assertEqual( - "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) - assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) - assertEqual( - "from (from a union all from b) c select *", - table("a").union(table("b")).as("c").select(star())) - } - - test("transform query spec") { - val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) - assertEqual("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - assertEqual("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) - assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - - test("multi select query") { - assertEqual( - "from a select * select * where s < 10", - table("a").select(star()).union(table("a").where('s < 10).select(star()))) - intercept( - "from a select * select * from x where a.s < 10", - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") - assertEqual( - "from a insert into tbl1 select * insert into tbl2 select * where s < 10", - table("a").select(star()).insertInto("tbl1").union( - table("a").where('s < 10).select(star()).insertInto("tbl2"))) - } - - test("query organization") { - // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows - val baseSql = "select * from t" - val basePlan = table("t").select(star()) - - val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) - val limitWindowClauses = Seq( - ("", (p: LogicalPlan) => p), - (" limit 10", (p: LogicalPlan) => p.limit(10)), - (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), - (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) - ) - - val orderSortDistrClusterClauses = Seq( - ("", basePlan), - (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), - (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), - (" distribute by a, b", basePlan.distribute('a, 'b)), - (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), - (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) - ) - - orderSortDistrClusterClauses.foreach { - case (s1, p1) => - limitWindowClauses.foreach { - case (s2, pf2) => - assertEqual(baseSql + s1 + s2, pf2(p1)) - } - } - - val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" - intercept(s"$baseSql order by a sort by a", msg) - intercept(s"$baseSql cluster by a distribute by a", msg) - intercept(s"$baseSql order by a cluster by a", msg) - intercept(s"$baseSql order by a distribute by a", msg) - } - - test("insert into") { - val sql = "select * from t" 
- val plan = table("t").select(star()) - def insert( - partition: Map[String, Option[String]], - overwrite: Boolean = false, - ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) - - // Single inserts - assertEqual(s"insert overwrite table s $sql", - insert(Map.empty, overwrite = true)) - assertEqual(s"insert overwrite table s if not exists $sql", - insert(Map.empty, overwrite = true, ifNotExists = true)) - assertEqual(s"insert into s $sql", - insert(Map.empty)) - assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", - insert(Map("c" -> Option("d"), "e" -> Option("1")))) - assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", - insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) - - // Multi insert - val plan2 = table("t").where('x > 5).select(star()) - assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", - InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( - InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) - } - - test("aggregation") { - val sql = "select a, b, sum(c) as c from d group by a, b" - - // Normal - assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) - - // Cube - assertEqual(s"$sql with cube", - table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) - - // Rollup - assertEqual(s"$sql with rollup", - table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) - - // Grouping Sets - assertEqual(s"$sql grouping sets((a, b), (a), ())", - GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) - intercept(s"$sql grouping sets((a, b), (c), ())", - "c doesn't show up in the GROUP BY list") - } - - test("limit") { - val sql = "select * from t" - val plan = table("t").select(star()) - assertEqual(s"$sql limit 10", plan.limit(10)) - assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) - } - - test("window spec") { - // Note that WindowSpecs are testing in the ExpressionParserSuite - val sql = "select * from t" - val plan = table("t").select(star()) - val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), - SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) - - // Test window resolution. - val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) - assertEqual( - s"""$sql - |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), - | w2 as w1, - | w3 as w1""".stripMargin, - WithWindowDefinition(ws1, plan)) - - // Fail with no reference. - intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") - - // Fail when resolved reference is not a window spec. 
- intercept( - s"""$sql - |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), - | w2 as w1, - | w3 as w2""".stripMargin, - "Window reference 'w2' is not a window specification" - ) - } - - test("lateral view") { - // Single lateral view - assertEqual( - "select * from t lateral view explode(x) expl as x", - table("t") - .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) - .select(star())) - - // Multiple lateral views - assertEqual( - """select * - |from t - |lateral view explode(x) expl - |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, - table("t") - .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) - .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) - .select(star())) - - // Multi-Insert lateral views. - val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) - assertEqual( - """from t1 - |lateral view explode(x) expl as x - |insert into t2 - |select * - |lateral view json_tuple(x, y) jtup q, z - |insert into t3 - |select * - |where s < 10 - """.stripMargin, - Union(from - .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) - .select(star()) - .insertInto("t2"), - from.where('s < 10).select(star()).insertInto("t3"))) - - // Unsupported generator. - intercept( - "select * from t lateral view posexplode(x) posexpl as x, y", - "Generator function 'posexplode' is not supported") - } - - test("joins") { - // Test single joins. - val testUnconditionalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t as tt $sql u", - table("t").as("tt").join(table("u"), jt, None).select(star())) - } - val testConditionalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t $sql u as uu on a = b", - table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) - } - val testNaturalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t tt natural $sql u as uu", - table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) - } - val testUsingJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t $sql u using(a, b)", - table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) - } - val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) - - def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { - tests.foreach(_(sql, jt)) - } - test("cross join", Inner, Seq(testUnconditionalJoin)) - test(",", Inner, Seq(testUnconditionalJoin)) - test("join", Inner, testAll) - test("inner join", Inner, testAll) - test("left join", LeftOuter, testAll) - test("left outer join", LeftOuter, testAll) - test("right join", RightOuter, testAll) - test("right outer join", RightOuter, testAll) - test("full join", FullOuter, testAll) - test("full outer join", FullOuter, testAll) - - // Test multiple consecutive joins - assertEqual( - "select * from a join b join c right join d", - table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) - } - - test("sampled relations") { - val sql = "select * from t" - assertEqual(s"$sql tablesample(100 rows)", - table("t").limit(100).select(star())) - assertEqual(s"$sql tablesample(43 percent) as x", - Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) - assertEqual(s"$sql tablesample(bucket 4 out of 10) 
as x", - Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) - intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", - "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") - intercept(s"$sql tablesample(bucket 11 out of 10) as x", - s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") - } - - test("sub-query") { - val plan = table("t0").select('id) - assertEqual("select id from (t0)", plan) - assertEqual("select id from ((((((t0))))))", plan) - assertEqual( - "(select * from t1) union distinct (select * from t2)", - Distinct(table("t1").select(star()).union(table("t2").select(star())))) - assertEqual( - "select * from ((select * from t1) union (select * from t2)) t", - Distinct( - table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) - assertEqual( - """select id - |from (((select id from t0) - | union all - | (select id from t0)) - | union all - | (select id from t0)) as u_1 - """.stripMargin, - plan.union(plan).union(plan).as("u_1").select('id)) - } - - test("scalar sub-query") { - assertEqual( - "select (select max(b) from s) ss from t", - table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) - assertEqual( - "select * from t where a = (select b from s)", - table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) - assertEqual( - "select g from t group by g having a > (select b from s)", - table("t") - .groupBy('g)('g) - .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) - } - - test("table reference") { - assertEqual("table t", table("t")) - assertEqual("table d.t", table("d", "t")) - } - - test("inline table") { - assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( - Seq('col1.int), - Seq(1, 2, 3, 4).map(x => Row(x)))) - assertEqual( - "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", - LocalRelation.fromExternalRows( - Seq('a.int, 'b.string), - Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) - intercept("values (a, 'a'), (b, 'b')", - "All expressions in an inline table must be constants.") - intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", - "Number of aliases must match the number of fields in an inline table.") - intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala deleted file mode 100644 index 0874322187..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.TableIdentifier - -class TableIdentifierParserSuite extends SparkFunSuite { - import CatalystSqlParser._ - - test("table identifier") { - // Regular names. - assert(TableIdentifier("q") === parseTableIdentifier("q")) - assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) - - // Illegal names. - intercept[ParseException](parseTableIdentifier("")) - intercept[ParseException](parseTableIdentifier("d.q.g")) - - // SQL Keywords. - val keywords = Seq("select", "from", "where", "left", "right") - keywords.foreach { keyword => - intercept[ParseException](parseTableIdentifier(keyword)) - assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) - assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala deleted file mode 100644 index 6fe04757ba..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution - -import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} -import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.types.StructType - -private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { - import ParserUtils._ - - /** Check if a command should not be explained. */ - protected def isNoExplainCommand(command: String): Boolean = { - "TOK_DESCTABLE" == command || "TOK_ALTERTABLE" == command - } - - /** - * For each node, extract properties in the form of a list - * ['key_part1', 'key_part2', 'key_part3', 'value'] - * into a pair (key_part1.key_part2.key_part3, value). 
- * - * Example format: - * - * TOK_TABLEPROPERTY - * :- 'k1' - * +- 'v1' - * TOK_TABLEPROPERTY - * :- 'k2' - * +- 'v2' - * TOK_TABLEPROPERTY - * :- 'k3' - * +- 'v3' - */ - private def extractProps( - props: Seq[ASTNode], - expectedNodeText: String): Seq[(String, String)] = { - props.map { - case Token(x, keysAndValue) if x == expectedNodeText => - val key = keysAndValue.init.map { x => unquoteString(x.text) }.mkString(".") - val value = unquoteString(keysAndValue.last.text) - (key, value) - case p => - parseFailed(s"Expected property '$expectedNodeText' in command", p) - } - } - - protected override def nodeToPlan(node: ASTNode): LogicalPlan = { - node match { - case Token("TOK_SETCONFIG", Nil) => - val keyValueSeparatorIndex = node.remainder.indexOf('=') - if (keyValueSeparatorIndex >= 0) { - val key = node.remainder.substring(0, keyValueSeparatorIndex).trim - val value = node.remainder.substring(keyValueSeparatorIndex + 1).trim - SetCommand(Some(key -> Option(value))) - } else if (node.remainder.nonEmpty) { - SetCommand(Some(node.remainder -> None)) - } else { - SetCommand(None) - } - - // Just fake explain for any of the native commands. - case Token("TOK_EXPLAIN", explainArgs) if isNoExplainCommand(explainArgs.head.text) => - ExplainCommand(OneRowRelation) - - case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.text => - val Some(crtTbl) :: _ :: extended :: Nil = - getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand(nodeToPlan(crtTbl), extended = extended.isDefined) - - case Token("TOK_EXPLAIN", explainArgs) => - // Ignore FORMATTED if present. - val Some(query) :: _ :: extended :: Nil = - getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand(nodeToPlan(query), extended = extended.isDefined) - - case Token("TOK_REFRESHTABLE", nameParts :: Nil) => - val tableIdent = extractTableIdent(nameParts) - RefreshTable(tableIdent) - - // CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] - // [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)]; - case Token("TOK_CREATEDATABASE", Token(dbName, Nil) :: args) => - val databaseName = cleanIdentifier(dbName) - val Seq(ifNotExists, dbLocation, databaseComment, dbprops) = getClauses(Seq( - "TOK_IFNOTEXISTS", - "TOK_DATABASELOCATION", - "TOK_DATABASECOMMENT", - "TOK_DATABASEPROPERTIES"), args) - val location = dbLocation.map { - case Token("TOK_DATABASELOCATION", Token(loc, Nil) :: Nil) => unquoteString(loc) - case _ => parseFailed("Invalid CREATE DATABASE command", node) - } - val comment = databaseComment.map { - case Token("TOK_DATABASECOMMENT", Token(com, Nil) :: Nil) => unquoteString(com) - case _ => parseFailed("Invalid CREATE DATABASE command", node) - } - val props = dbprops.toSeq.flatMap { - case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) => - // Example format: - // - // TOK_DATABASEPROPERTIES - // +- TOK_DBPROPLIST - // :- TOK_TABLEPROPERTY - // : :- 'k1' - // : +- 'v1' - // :- TOK_TABLEPROPERTY - // :- 'k2' - // +- 'v2' - extractProps(propList, "TOK_TABLEPROPERTY") - case _ => parseFailed("Invalid CREATE DATABASE command", node) - }.toMap - CreateDatabase(databaseName, ifNotExists.isDefined, location, comment, props) - - // DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]; - case Token("TOK_DROPDATABASE", Token(dbName, Nil) :: otherArgs) => - // Example format: - // - // TOK_DROPDATABASE - // :- database_name - // :- TOK_IFEXISTS - // +- TOK_RESTRICT/TOK_CASCADE - val 
databaseName = cleanIdentifier(dbName) - // The default is RESTRICT - val Seq(ifExists, _, cascade) = getClauses(Seq( - "TOK_IFEXISTS", "TOK_RESTRICT", "TOK_CASCADE"), otherArgs) - DropDatabase(databaseName, ifExists.isDefined, cascade.isDefined) - - // ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) - case Token("TOK_ALTERDATABASE_PROPERTIES", Token(dbName, Nil) :: args) => - val databaseName = cleanIdentifier(dbName) - val dbprops = getClause("TOK_DATABASEPROPERTIES", args) - val props = dbprops match { - case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) => - // Example format: - // - // TOK_DATABASEPROPERTIES - // +- TOK_DBPROPLIST - // :- TOK_TABLEPROPERTY - // : :- 'k1' - // : +- 'v1' - // :- TOK_TABLEPROPERTY - // :- 'k2' - // +- 'v2' - extractProps(propList, "TOK_TABLEPROPERTY") - case _ => parseFailed("Invalid ALTER DATABASE command", node) - } - AlterDatabaseProperties(databaseName, props.toMap) - - // DESCRIBE DATABASE [EXTENDED] db_name - case Token("TOK_DESCDATABASE", Token(dbName, Nil) :: describeArgs) => - val databaseName = cleanIdentifier(dbName) - val extended = getClauseOption("EXTENDED", describeArgs) - DescribeDatabase(databaseName, extended.isDefined) - - // CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name - // [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri'] ]; - case Token("TOK_CREATEFUNCTION", args) => - // Example format: - // - // TOK_CREATEFUNCTION - // :- db_name - // :- func_name - // :- alias - // +- TOK_RESOURCE_LIST - // :- TOK_RESOURCE_URI - // : :- TOK_JAR - // : +- '/path/to/jar' - // +- TOK_RESOURCE_URI - // :- TOK_FILE - // +- 'path/to/file' - val (funcNameArgs, otherArgs) = args.partition { - case Token("TOK_RESOURCE_LIST", _) => false - case Token("TOK_TEMPORARY", _) => false - case Token(_, Nil) => true - case _ => parseFailed("Invalid CREATE FUNCTION command", node) - } - // If database name is specified, there are 3 tokens, otherwise 2. 
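Both this CREATE FUNCTION branch and the DROP FUNCTION branch below resolve an optional database qualifier from the leading name tokens: one extra token means the name is db-qualified. A stripped-down sketch of that pattern, using a stand-in Token case class rather than the Hive ASTNode/Token extractor:

    case class Token(text: String, children: List[Token] = Nil)

    // Two name tokens mean "db.func", one means just "func"; with a trailing
    // alias token the same shape gives the 3-vs-2 token case for CREATE FUNCTION.
    def functionName(tokens: List[Token]): (Option[String], String) = tokens match {
      case Token(db, Nil) :: Token(fn, Nil) :: Nil => (Some(db), fn)
      case Token(fn, Nil) :: Nil                   => (None, fn)
      case other => sys.error(s"Invalid function name tokens: $other")
    }

    // functionName(List(Token("mydb"), Token("myfunc"))) == (Some("mydb"), "myfunc")
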
- val (dbName, funcName, alias) = funcNameArgs match { - case Token(dbName, Nil) :: Token(fname, Nil) :: Token(aname, Nil) :: Nil => - (Some(unquoteString(dbName)), unquoteString(fname), unquoteString(aname)) - case Token(fname, Nil) :: Token(aname, Nil) :: Nil => - (None, unquoteString(fname), unquoteString(aname)) - case _ => - parseFailed("Invalid CREATE FUNCTION command", node) - } - // Extract other keywords, if they exist - val Seq(rList, temp) = getClauses(Seq("TOK_RESOURCE_LIST", "TOK_TEMPORARY"), otherArgs) - val resources: Seq[(String, String)] = rList.toSeq.flatMap { - case Token("TOK_RESOURCE_LIST", resList) => - resList.map { - case Token("TOK_RESOURCE_URI", rType :: Token(rPath, Nil) :: Nil) => - val resourceType = rType match { - case Token("TOK_JAR", Nil) => "jar" - case Token("TOK_FILE", Nil) => "file" - case Token("TOK_ARCHIVE", Nil) => "archive" - case Token(f, _) => parseFailed(s"Unexpected resource format '$f'", node) - } - (resourceType, unquoteString(rPath)) - case _ => parseFailed("Invalid CREATE FUNCTION command", node) - } - case _ => parseFailed("Invalid CREATE FUNCTION command", node) - } - CreateFunction(dbName, funcName, alias, resources, temp.isDefined)(node.source) - - // DROP [TEMPORARY] FUNCTION [IF EXISTS] function_name; - case Token("TOK_DROPFUNCTION", args) => - // Example format: - // - // TOK_DROPFUNCTION - // :- db_name - // :- func_name - // :- TOK_IFEXISTS - // +- TOK_TEMPORARY - val (funcNameArgs, otherArgs) = args.partition { - case Token("TOK_IFEXISTS", _) => false - case Token("TOK_TEMPORARY", _) => false - case Token(_, Nil) => true - case _ => parseFailed("Invalid DROP FUNCTION command", node) - } - // If database name is specified, there are 2 tokens, otherwise 1. - val (dbName, funcName) = funcNameArgs match { - case Token(dbName, Nil) :: Token(fname, Nil) :: Nil => - (Some(unquoteString(dbName)), unquoteString(fname)) - case Token(fname, Nil) :: Nil => - (None, unquoteString(fname)) - case _ => - parseFailed("Invalid DROP FUNCTION command", node) - } - - val Seq(ifExists, temp) = getClauses(Seq( - "TOK_IFEXISTS", "TOK_TEMPORARY"), otherArgs) - - DropFunction(dbName, funcName, ifExists.isDefined, temp.isDefined)(node.source) - - case Token("TOK_ALTERTABLE", alterTableArgs) => - AlterTableCommandParser.parse(node) - - case Token("TOK_CREATETABLEUSING", createTableArgs) => - val Seq( - temp, - ifNotExists, - Some(tabName), - tableCols, - Some(Token("TOK_TABLEPROVIDER", providerNameParts)), - tableOpts, - tableAs) = getClauses(Seq( - "TEMPORARY", - "TOK_IFNOTEXISTS", - "TOK_TABNAME", "TOK_TABCOLLIST", - "TOK_TABLEPROVIDER", - "TOK_TABLEOPTIONS", - "TOK_QUERY"), createTableArgs) - val tableIdent: TableIdentifier = extractTableIdent(tabName) - val columns = tableCols.map { - case Token("TOK_TABCOLLIST", fields) => StructType(fields.map(nodeToStructField)) - case _ => parseFailed("Invalid CREATE TABLE command", node) - } - val provider = providerNameParts.map { - case Token(name, Nil) => name - case _ => parseFailed("Invalid CREATE TABLE command", node) - }.mkString(".") - val options = tableOpts.toSeq.flatMap { - case Token("TOK_TABLEOPTIONS", opts) => extractProps(opts, "TOK_TABLEOPTION") - case _ => parseFailed("Invalid CREATE TABLE command", node) - }.toMap - val asClause = tableAs.map(nodeToPlan) - - if (temp.isDefined && ifNotExists.isDefined) { - throw new AnalysisException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - if (asClause.isDefined) { - if (columns.isDefined) { - throw new 
AnalysisException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - - val mode = if (ifNotExists.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - CreateTableUsingAsSelect(tableIdent, - provider, - temp.isDefined, - Array.empty[String], - bucketSpec = None, - mode, - options, - asClause.get) - } else { - CreateTableUsing( - tableIdent, - columns, - provider, - temp.isDefined, - options, - ifNotExists.isDefined, - managedIfNoPath = false) - } - - case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) => - SetDatabaseCommand(cleanIdentifier(database)) - - case Token("TOK_DESCTABLE", describeArgs) => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val Some(tableType) :: formatted :: extended :: pretty :: Nil = - getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) - if (formatted.isDefined || pretty.isDefined) { - // FORMATTED and PRETTY are not supported and this statement will be treated as - // a Hive native command. - nodeToDescribeFallback(node) - } else { - tableType match { - case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts) :: Nil) => - nameParts match { - case Token(dbName, Nil) :: Token(tableName, Nil) :: Nil => - // It is describing a table with the format like "describe db.table". - // TODO: Actually, a user may mean tableName.columnName. Need to resolve this - // issue. - val tableIdent = TableIdentifier( - cleanIdentifier(tableName), Some(cleanIdentifier(dbName))) - datasources.DescribeCommand(tableIdent, isExtended = extended.isDefined) - case Token(dbName, Nil) :: Token(tableName, Nil) :: Token(colName, Nil) :: Nil => - // It is describing a column with the format like "describe db.table column". - nodeToDescribeFallback(node) - case tableName :: Nil => - // It is describing a table with the format like "describe table". - datasources.DescribeCommand( - TableIdentifier(cleanIdentifier(tableName.text)), - isExtended = extended.isDefined) - case _ => - nodeToDescribeFallback(node) - } - // All other cases. 
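The CREATE TABLE ... USING ... AS SELECT branch above chooses a SaveMode from the IF NOT EXISTS and TEMPORARY modifiers (the combination of the two is rejected earlier in the branch). A condensed sketch of that decision, using plain strings in place of org.apache.spark.sql.SaveMode:

    // IF NOT EXISTS wins, TEMPORARY tables are overwritten, and a plain CTAS
    // fails if the target table already exists.
    def ctasSaveMode(ifNotExists: Boolean, temporary: Boolean): String =
      if (ifNotExists) "Ignore"
      else if (temporary) "Overwrite"
      else "ErrorIfExists"
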
- case _ => - nodeToDescribeFallback(node) - } - } - - case Token("TOK_CACHETABLE", Token(tableName, Nil) :: args) => - val Seq(lzy, selectAst) = getClauses(Seq("LAZY", "TOK_QUERY"), args) - CacheTableCommand(tableName, selectAst.map(nodeToPlan), lzy.isDefined) - - case Token("TOK_UNCACHETABLE", Token(tableName, Nil) :: Nil) => - UncacheTableCommand(tableName) - - case Token("TOK_CLEARCACHE", Nil) => - ClearCacheCommand - - case Token("TOK_SHOWTABLES", args) => - val databaseName = args match { - case Nil => None - case Token("TOK_FROM", Token(dbName, Nil) :: Nil) :: Nil => Option(dbName) - case _ => noParseRule("SHOW TABLES", node) - } - ShowTablesCommand(databaseName) - - case _ => - super.nodeToPlan(node) - } - } - - protected def nodeToDescribeFallback(node: ASTNode): LogicalPlan = noParseRule("Describe", node) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8333074eca..b4687c985d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -20,8 +20,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder, ParseException} -import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.parser.{AbstractSqlParser, AstBuilder, ParseException} +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} import org.apache.spark.sql.execution.datasources._ @@ -37,7 +37,7 @@ object SparkSqlParser extends AbstractSqlParser{ * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ class SparkSqlAstBuilder extends AstBuilder { - import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._ + import org.apache.spark.sql.catalyst.parser.ParserUtils._ /** * Create a [[SetCommand]] logical plan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala deleted file mode 100644 index 9fbe6db467..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala +++ /dev/null @@ -1,431 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.command - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, SortDirection} -import org.apache.spark.sql.catalyst.parser._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.types.StructType - - -/** - * Helper object to parse alter table commands. - */ -object AlterTableCommandParser { - import ParserUtils._ - - /** - * Parse the given node assuming it is an alter table command. - */ - def parse(node: ASTNode): LogicalPlan = { - node.children match { - case (tabName @ Token("TOK_TABNAME", _)) :: otherNodes => - val tableIdent = extractTableIdent(tabName) - val partSpec = getClauseOption("TOK_PARTSPEC", node.children).map(parsePartitionSpec) - matchAlterTableCommands(node, otherNodes, tableIdent, partSpec) - case _ => - parseFailed("Could not parse ALTER TABLE command", node) - } - } - - private def cleanAndUnquoteString(s: String): String = { - cleanIdentifier(unquoteString(s)) - } - - /** - * Extract partition spec from the given [[ASTNode]] as a map, assuming it exists. - * - * Example format: - * - * TOK_PARTSPEC - * :- TOK_PARTVAL - * : :- dt - * : +- '2008-08-08' - * +- TOK_PARTVAL - * :- country - * +- 'us' - */ - private def parsePartitionSpec(node: ASTNode): Map[String, String] = { - node match { - case Token("TOK_PARTSPEC", partitions) => - partitions.map { - // Note: sometimes there's a "=", "<" or ">" between the key and the value - // (e.g. when dropping all partitions with value > than a certain constant) - case Token("TOK_PARTVAL", ident :: conj :: constant :: Nil) => - (cleanAndUnquoteString(ident.text), cleanAndUnquoteString(constant.text)) - case Token("TOK_PARTVAL", ident :: constant :: Nil) => - (cleanAndUnquoteString(ident.text), cleanAndUnquoteString(constant.text)) - case Token("TOK_PARTVAL", ident :: Nil) => - (cleanAndUnquoteString(ident.text), null) - case _ => - parseFailed("Invalid ALTER TABLE command", node) - }.toMap - case _ => - parseFailed("Expected partition spec in ALTER TABLE command", node) - } - } - - /** - * Extract table properties from the given [[ASTNode]] as a map, assuming it exists. - * - * Example format: - * - * TOK_TABLEPROPERTIES - * +- TOK_TABLEPROPLIST - * :- TOK_TABLEPROPERTY - * : :- 'test' - * : +- 'value' - * +- TOK_TABLEPROPERTY - * :- 'comment' - * +- 'new_comment' - */ - private def extractTableProps(node: ASTNode): Map[String, String] = { - node match { - case Token("TOK_TABLEPROPERTIES", propsList) => - propsList.flatMap { - case Token("TOK_TABLEPROPLIST", props) => - props.map { case Token("TOK_TABLEPROPERTY", key :: value :: Nil) => - val k = cleanAndUnquoteString(key.text) - val v = value match { - case Token("TOK_NULL", Nil) => null - case _ => cleanAndUnquoteString(value.text) - } - (k, v) - } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - }.toMap - case _ => - parseFailed("Expected table properties in ALTER TABLE command", node) - } - } - - /** - * Parse an alter table command from a [[ASTNode]] into a [[LogicalPlan]]. - * This follows https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL. - * - * @param node the original [[ASTNode]] to parse. - * @param otherNodes the other [[ASTNode]]s after the first one containing the table name. 
- * @param tableIdent identifier of the table, parsed from the first [[ASTNode]]. - * @param partition spec identifying the partition this command is concerned with, if any. - */ - // TODO: This method is massive. Break it down. - private def matchAlterTableCommands( - node: ASTNode, - otherNodes: Seq[ASTNode], - tableIdent: TableIdentifier, - partition: Option[TablePartitionSpec]): LogicalPlan = { - otherNodes match { - // ALTER TABLE table_name RENAME TO new_table_name; - case Token("TOK_ALTERTABLE_RENAME", renameArgs) :: _ => - val tableNameClause = getClause("TOK_TABNAME", renameArgs) - val newTableIdent = extractTableIdent(tableNameClause) - AlterTableRename(tableIdent, newTableIdent)(node.source) - - // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); - case Token("TOK_ALTERTABLE_PROPERTIES", args) :: _ => - val properties = extractTableProps(args.head) - AlterTableSetProperties(tableIdent, properties)(node.source) - - // ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); - case Token("TOK_ALTERTABLE_DROPPROPERTIES", args) :: _ => - val properties = extractTableProps(args.head) - val ifExists = getClauseOption("TOK_IFEXISTS", args).isDefined - AlterTableUnsetProperties(tableIdent, properties, ifExists)(node.source) - - // ALTER TABLE table_name [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props]; - case Token("TOK_ALTERTABLE_SERIALIZER", Token(serdeClassName, Nil) :: serdeArgs) :: _ => - AlterTableSerDeProperties( - tableIdent, - Some(cleanAndUnquoteString(serdeClassName)), - serdeArgs.headOption.map(extractTableProps), - partition)(node.source) - - // ALTER TABLE table_name [PARTITION spec] SET SERDEPROPERTIES serde_properties; - case Token("TOK_ALTERTABLE_SERDEPROPERTIES", args) :: _ => - AlterTableSerDeProperties( - tableIdent, - None, - Some(extractTableProps(args.head)), - partition)(node.source) - - // ALTER TABLE table_name CLUSTERED BY (col, ...) [SORTED BY (col, ...)] INTO n BUCKETS; - case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_ALTERTABLE_BUCKETS", b) :: Nil) :: _ => - val clusterCols: Seq[String] = b.head match { - case Token("TOK_TABCOLNAME", children) => children.map(_.text) - case _ => parseFailed("Invalid ALTER TABLE command", node) - } - // If sort columns are specified, num buckets should be the third arg. - // If sort columns are not specified, num buckets should be the second arg. 
- // TODO: actually use `sortDirections` once we actually store that in the metastore - val (sortCols: Seq[String], sortDirections: Seq[SortDirection], numBuckets: Int) = { - b.tail match { - case Token("TOK_TABCOLNAME", children) :: numBucketsNode :: Nil => - val (cols, directions) = children.map { - case Token("TOK_TABSORTCOLNAMEASC", Token(col, Nil) :: Nil) => (col, Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", Token(col, Nil) :: Nil) => (col, Descending) - }.unzip - (cols, directions, numBucketsNode.text.toInt) - case numBucketsNode :: Nil => - (Nil, Nil, numBucketsNode.text.toInt) - case _ => - parseFailed("Invalid ALTER TABLE command", node) - } - } - AlterTableStorageProperties( - tableIdent, - BucketSpec(numBuckets, clusterCols, sortCols))(node.source) - - // ALTER TABLE table_name NOT CLUSTERED - case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_NOT_CLUSTERED", Nil) :: Nil) :: _ => - AlterTableNotClustered(tableIdent)(node.source) - - // ALTER TABLE table_name NOT SORTED - case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_NOT_SORTED", Nil) :: Nil) :: _ => - AlterTableNotSorted(tableIdent)(node.source) - - // ALTER TABLE table_name SKEWED BY (col1, col2) - // ON ((col1_value, col2_value) [, (col1_value, col2_value), ...]) - // [STORED AS DIRECTORIES]; - case Token("TOK_ALTERTABLE_SKEWED", - Token("TOK_TABLESKEWED", - Token("TOK_TABCOLNAME", colNames) :: colValues :: rest) :: Nil) :: _ => - // Example format: - // - // TOK_ALTERTABLE_SKEWED - // :- TOK_TABLESKEWED - // : :- TOK_TABCOLNAME - // : : :- dt - // : : +- country - // :- TOK_TABCOLVALUE_PAIR - // : :- TOK_TABCOLVALUES - // : : :- TOK_TABCOLVALUE - // : : : :- '2008-08-08' - // : : : +- 'us' - // : :- TOK_TABCOLVALUES - // : : :- TOK_TABCOLVALUE - // : : : :- '2009-09-09' - // : : : +- 'uk' - // +- TOK_STOREASDIR - val names = colNames.map { n => cleanAndUnquoteString(n.text) } - val values = colValues match { - case Token("TOK_TABCOLVALUE", vals) => - Seq(vals.map { n => cleanAndUnquoteString(n.text) }) - case Token("TOK_TABCOLVALUE_PAIR", pairs) => - pairs.map { - case Token("TOK_TABCOLVALUES", Token("TOK_TABCOLVALUE", vals) :: Nil) => - vals.map { n => cleanAndUnquoteString(n.text) } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - } - val storedAsDirs = rest match { - case Token("TOK_STOREDASDIRS", Nil) :: Nil => true - case _ => false - } - AlterTableSkewed( - tableIdent, - names, - values, - storedAsDirs)(node.source) - - // ALTER TABLE table_name NOT SKEWED - case Token("TOK_ALTERTABLE_SKEWED", Nil) :: _ => - AlterTableNotSkewed(tableIdent)(node.source) - - // ALTER TABLE table_name NOT STORED AS DIRECTORIES - case Token("TOK_ALTERTABLE_SKEWED", Token("TOK_STOREDASDIRS", Nil) :: Nil) :: _ => - AlterTableNotStoredAsDirs(tableIdent)(node.source) - - // ALTER TABLE table_name SET SKEWED LOCATION (col1="loc1" [, (col2, col3)="loc2", ...] 
); - case Token("TOK_ALTERTABLE_SKEWED_LOCATION", - Token("TOK_SKEWED_LOCATIONS", - Token("TOK_SKEWED_LOCATION_LIST", locationMaps) :: Nil) :: Nil) :: _ => - // Example format: - // - // TOK_ALTERTABLE_SKEWED_LOCATION - // +- TOK_SKEWED_LOCATIONS - // +- TOK_SKEWED_LOCATION_LIST - // :- TOK_SKEWED_LOCATION_MAP - // : :- 'col1' - // : +- 'loc1' - // +- TOK_SKEWED_LOCATION_MAP - // :- TOK_TABCOLVALUES - // : +- TOK_TABCOLVALUE - // : :- 'col2' - // : +- 'col3' - // +- 'loc2' - val skewedMaps = locationMaps.flatMap { - case Token("TOK_SKEWED_LOCATION_MAP", col :: loc :: Nil) => - col match { - case Token(const, Nil) => - Seq((cleanAndUnquoteString(const), cleanAndUnquoteString(loc.text))) - case Token("TOK_TABCOLVALUES", Token("TOK_TABCOLVALUE", keys) :: Nil) => - keys.map { k => (cleanAndUnquoteString(k.text), cleanAndUnquoteString(loc.text)) } - } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - }.toMap - AlterTableSkewedLocation(tableIdent, skewedMaps)(node.source) - - // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] - // spec [LOCATION 'loc2'] ...; - case Token("TOK_ALTERTABLE_ADDPARTS", args) :: _ => - val (ifNotExists, parts) = args.head match { - case Token("TOK_IFNOTEXISTS", Nil) => (true, args.tail) - case _ => (false, args) - } - // List of (spec, location) to describe partitions to add - // Each partition spec may or may not be followed by a location - val parsedParts = new ArrayBuffer[(TablePartitionSpec, Option[String])] - parts.foreach { - case t @ Token("TOK_PARTSPEC", _) => - parsedParts += ((parsePartitionSpec(t), None)) - case Token("TOK_PARTITIONLOCATION", loc :: Nil) => - // Update the location of the last partition we just added - if (parsedParts.nonEmpty) { - val (spec, _) = parsedParts.remove(parsedParts.length - 1) - parsedParts += ((spec, Some(unquoteString(loc.text)))) - } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - } - AlterTableAddPartition(tableIdent, parsedParts, ifNotExists)(node.source) - - // ALTER TABLE table_name PARTITION spec1 RENAME TO PARTITION spec2; - case Token("TOK_ALTERTABLE_RENAMEPART", spec :: Nil) :: _ => - val newPartition = parsePartitionSpec(spec) - val oldPartition = partition.getOrElse { - parseFailed("Expected old partition spec in ALTER TABLE rename partition command", node) - } - AlterTableRenamePartition(tableIdent, oldPartition, newPartition)(node.source) - - // ALTER TABLE table_name_1 EXCHANGE PARTITION spec WITH TABLE table_name_2; - case Token("TOK_ALTERTABLE_EXCHANGEPARTITION", spec :: newTable :: Nil) :: _ => - val parsedSpec = parsePartitionSpec(spec) - val newTableIdent = extractTableIdent(newTable) - AlterTableExchangePartition(tableIdent, newTableIdent, parsedSpec)(node.source) - - // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] 
[PURGE]; - case Token("TOK_ALTERTABLE_DROPPARTS", args) :: _ => - val parts = args.collect { case p @ Token("TOK_PARTSPEC", _) => parsePartitionSpec(p) } - val ifExists = getClauseOption("TOK_IFEXISTS", args).isDefined - val purge = getClauseOption("PURGE", args).isDefined - AlterTableDropPartition(tableIdent, parts, ifExists, purge)(node.source) - - // ALTER TABLE table_name ARCHIVE PARTITION spec; - case Token("TOK_ALTERTABLE_ARCHIVE", spec :: Nil) :: _ => - AlterTableArchivePartition(tableIdent, parsePartitionSpec(spec))(node.source) - - // ALTER TABLE table_name UNARCHIVE PARTITION spec; - case Token("TOK_ALTERTABLE_UNARCHIVE", spec :: Nil) :: _ => - AlterTableUnarchivePartition(tableIdent, parsePartitionSpec(spec))(node.source) - - // ALTER TABLE table_name [PARTITION spec] SET FILEFORMAT file_format; - case Token("TOK_ALTERTABLE_FILEFORMAT", args) :: _ => - val Seq(fileFormat, genericFormat) = - getClauses(Seq("TOK_TABLEFILEFORMAT", "TOK_FILEFORMAT_GENERIC"), args) - // Note: the AST doesn't contain information about which file format is being set here. - // E.g. we can't differentiate between INPUTFORMAT and OUTPUTFORMAT if either is set. - // Right now this just stores the values, but we should figure out how to get the keys. - val fFormat = fileFormat - .map { _.children.map { n => cleanAndUnquoteString(n.text) }} - .getOrElse(Seq()) - val gFormat = genericFormat.map { f => cleanAndUnquoteString(f.children(0).text) } - AlterTableSetFileFormat(tableIdent, partition, fFormat, gFormat)(node.source) - - // ALTER TABLE table_name [PARTITION spec] SET LOCATION "loc"; - case Token("TOK_ALTERTABLE_LOCATION", Token(loc, Nil) :: Nil) :: _ => - AlterTableSetLocation(tableIdent, partition, cleanAndUnquoteString(loc))(node.source) - - // ALTER TABLE table_name TOUCH [PARTITION spec]; - case Token("TOK_ALTERTABLE_TOUCH", args) :: _ => - // Note: the partition spec, if it exists, comes after TOUCH, so `partition` should - // always be None here. Instead, we need to parse it from the TOUCH node's children. 
- val part = getClauseOption("TOK_PARTSPEC", args).map(parsePartitionSpec) - AlterTableTouch(tableIdent, part)(node.source) - - // ALTER TABLE table_name [PARTITION spec] COMPACT 'compaction_type'; - case Token("TOK_ALTERTABLE_COMPACT", Token(compactType, Nil) :: Nil) :: _ => - AlterTableCompact(tableIdent, partition, cleanAndUnquoteString(compactType))(node.source) - - // ALTER TABLE table_name [PARTITION spec] CONCATENATE; - case Token("TOK_ALTERTABLE_MERGEFILES", _) :: _ => - AlterTableMerge(tableIdent, partition)(node.source) - - // ALTER TABLE table_name [PARTITION spec] CHANGE [COLUMN] col_old_name col_new_name - // column_type [COMMENT col_comment] [FIRST|AFTER column_name] [CASCADE|RESTRICT]; - case Token("TOK_ALTERTABLE_RENAMECOL", oldName :: newName :: dataType :: args) :: _ => - val afterColName: Option[String] = - getClauseOption("TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION", args).map { ap => - ap.children match { - case Token(col, Nil) :: Nil => col - case _ => parseFailed("Invalid ALTER TABLE command", node) - } - } - val restrict = getClauseOption("TOK_RESTRICT", args).isDefined - val cascade = getClauseOption("TOK_CASCADE", args).isDefined - val comment = args.headOption.map { - case Token("TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION", _) => null - case Token("TOK_RESTRICT", _) => null - case Token("TOK_CASCADE", _) => null - case Token(commentStr, Nil) => cleanAndUnquoteString(commentStr) - case _ => parseFailed("Invalid ALTER TABLE command", node) - } - AlterTableChangeCol( - tableIdent, - partition, - oldName.text, - newName.text, - nodeToDataType(dataType), - comment, - afterColName, - restrict, - cascade)(node.source) - - // ALTER TABLE table_name [PARTITION spec] ADD COLUMNS (name type [COMMENT comment], ...) - // [CASCADE|RESTRICT] - case Token("TOK_ALTERTABLE_ADDCOLS", args) :: _ => - val columnNodes = getClause("TOK_TABCOLLIST", args).children - val columns = StructType(columnNodes.map(nodeToStructField)) - val restrict = getClauseOption("TOK_RESTRICT", args).isDefined - val cascade = getClauseOption("TOK_CASCADE", args).isDefined - AlterTableAddCol(tableIdent, partition, columns, restrict, cascade)(node.source) - - // ALTER TABLE table_name [PARTITION spec] REPLACE COLUMNS (name type [COMMENT comment], ...) 
- // [CASCADE|RESTRICT] - case Token("TOK_ALTERTABLE_REPLACECOLS", args) :: _ => - val columnNodes = getClause("TOK_TABCOLLIST", args).children - val columns = StructType(columnNodes.map(nodeToStructField)) - val restrict = getClauseOption("TOK_RESTRICT", args).isDefined - val cascade = getClauseOption("TOK_CASCADE", args).isDefined - AlterTableReplaceCol(tableIdent, partition, columns, restrict, cascade)(node.source) - - case _ => - parseFailed("Unsupported ALTER TABLE command", node) - } - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d06e9086a3..6cc72fba48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -26,7 +26,6 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -500,19 +499,6 @@ object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query." ) - val PARSER_SUPPORT_QUOTEDID = booleanConf("spark.sql.parser.supportQuotedIdentifiers", - defaultValue = Some(true), - isPublic = false, - doc = "Whether to use quoted identifier.\n false: default(past) behavior. Implies only" + - "alphaNumeric and underscore are valid characters in identifiers.\n" + - " true: implies column names can contain any character.") - - val PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS = booleanConf( - "spark.sql.parser.supportSQL11ReservedKeywords", - defaultValue = Some(false), - isPublic = false, - doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.") - val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage", defaultValue = Some(true), doc = "When true, the whole stage (of multiple operators) will be compiled into single java" + @@ -573,7 +559,7 @@ object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { +class SQLConf extends Serializable with CatalystConf with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. 
*/ @@ -674,10 +660,6 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) - def supportQuotedId: Boolean = getConf(PARSER_SUPPORT_QUOTEDID) - - def supportSQL11ReservedKeywords: Boolean = getConf(PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS) - override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 47c9a22acd..f148f2d4ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -21,11 +21,22 @@ import java.io.File import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.catalog.CatalogDatabase -import org.apache.spark.sql.catalyst.parser.ParserUtils._ import org.apache.spark.sql.test.SharedSQLContext class DDLSuite extends QueryTest with SharedSQLContext { + private val escapedIdentifier = "`(.+)`".r + + /** + * Strip backticks, if any, from the string. + */ + def cleanIdentifier(ident: String): String = { + ident match { + case escapedIdentifier(i) => i + case plainIdent => plainIdent + } + } + /** * Drops database `databaseName` after calling `f`. */ diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 22bad93e6d..58efd80512 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -225,25 +225,6 @@ -da -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - - org.codehaus.mojo - build-helper-maven-plugin - - - add-default-sources - generate-sources - - add-source - - - - v${hive.version.short}/src/main/scala - ${project.build.directory/generated-sources/antlr - - - - - diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 9a5ec9880e..2cdc931c3f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -25,7 +25,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.{TableType => HiveTableType, Warehouse} +import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.hadoop.hive.ql.metadata.{Table => HiveTable, _} import org.apache.hadoop.hive.ql.plan.TableDesc @@ -988,3 +988,28 @@ private[hive] object HiveMetastoreTypes { case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType) } } + +private[hive] case class CreateTableAsSelect( + tableDesc: CatalogTable, + child: LogicalPlan, + allowExisting: Boolean) extends UnaryNode with Command { + + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = + tableDesc.identifier.database.isDefined && + tableDesc.schema.nonEmpty && + tableDesc.storage.serde.isDefined && + tableDesc.storage.inputFormat.isDefined && + tableDesc.storage.outputFormat.isDefined && + childrenResolved +} + +private[hive] case class CreateViewAsSelect( + tableDesc: CatalogTable, + child: LogicalPlan, + 
allowExisting: Boolean, + replace: Boolean, + sql: String) extends UnaryNode with Command { + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = false +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala deleted file mode 100644 index 052c43a3ce..0000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ /dev/null @@ -1,749 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.util.Locale - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} -import org.apache.hadoop.hive.ql.parse.EximUtil -import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.SparkQl -import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper -import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.AnalysisException - -/** - * Used when we need to start parsing the AST before deciding that we are going to pass the command - * back for Hive to execute natively. Will be replaced with a native command that contains the - * cmd string. 
- */ -private[hive] case object NativePlaceholder extends LogicalPlan { - override def children: Seq[LogicalPlan] = Seq.empty - override def output: Seq[Attribute] = Seq.empty -} - -private[hive] case class CreateTableAsSelect( - tableDesc: CatalogTable, - child: LogicalPlan, - allowExisting: Boolean) extends UnaryNode with Command { - - override def output: Seq[Attribute] = Seq.empty[Attribute] - override lazy val resolved: Boolean = - tableDesc.identifier.database.isDefined && - tableDesc.schema.nonEmpty && - tableDesc.storage.serde.isDefined && - tableDesc.storage.inputFormat.isDefined && - tableDesc.storage.outputFormat.isDefined && - childrenResolved -} - -private[hive] case class CreateViewAsSelect( - tableDesc: CatalogTable, - child: LogicalPlan, - allowExisting: Boolean, - replace: Boolean, - sql: String) extends UnaryNode with Command { - override def output: Seq[Attribute] = Seq.empty[Attribute] - override lazy val resolved: Boolean = false -} - -/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging { - import ParseUtils._ - import ParserUtils._ - - // Token text -> human readable text - private val hiveUnsupportedCommands = Map( - "TOK_CREATEROLE" -> "CREATE ROLE", - "TOK_DROPROLE" -> "DROP ROLE", - "TOK_EXPORT" -> "EXPORT TABLE", - "TOK_GRANT" -> "GRANT", - "TOK_GRANT_ROLE" -> "GRANT", - "TOK_IMPORT" -> "IMPORT TABLE", - "TOK_REVOKE" -> "REVOKE", - "TOK_REVOKE_ROLE" -> "REVOKE", - "TOK_SHOW_COMPACTIONS" -> "SHOW COMPACTIONS", - "TOK_SHOW_CREATETABLE" -> "SHOW CREATE TABLE", - "TOK_SHOW_GRANT" -> "SHOW GRANT", - "TOK_SHOW_ROLE_GRANT" -> "SHOW ROLE GRANT", - "TOK_SHOW_ROLE_PRINCIPALS" -> "SHOW PRINCIPALS", - "TOK_SHOW_ROLES" -> "SHOW ROLES", - "TOK_SHOW_SET_ROLE" -> "SHOW CURRENT ROLES / SET ROLE", - "TOK_SHOW_TRANSACTIONS" -> "SHOW TRANSACTIONS", - "TOK_SHOWINDEXES" -> "SHOW INDEXES", - "TOK_SHOWLOCKS" -> "SHOW LOCKS") - - private val nativeCommands = Set( - "TOK_ALTERDATABASE_OWNER", - "TOK_ALTERINDEX_PROPERTIES", - "TOK_ALTERINDEX_REBUILD", - "TOK_ALTERTABLE_ALTERPARTS", - "TOK_ALTERTABLE_PARTITION", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_AS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME", - - "TOK_CREATEINDEX", - "TOK_CREATEMACRO", - - "TOK_DROPINDEX", - "TOK_DROPMACRO", - "TOK_DROPTABLE_PROPERTIES", - "TOK_DROPVIEW", - "TOK_DROPVIEW_PROPERTIES", - - "TOK_LOAD", - - "TOK_LOCKTABLE", - - "TOK_MSCK", - - "TOK_SHOW_TABLESTATUS", - "TOK_SHOW_TBLPROPERTIES", - "TOK_SHOWCOLUMNS", - "TOK_SHOWDATABASES", - "TOK_SHOWPARTITIONS", - - "TOK_UNLOCKTABLE" - ) - - // Commands that we do not need to explain. - private val noExplainCommands = Set( - "TOK_DESCTABLE", - "TOK_SHOWTABLES", - "TOK_TRUNCATETABLE", // truncate table" is a NativeCommand, does not need to explain. 
- "TOK_ALTERTABLE" - ) ++ nativeCommands - - /** - * Returns the HiveConf - */ - private[this] def hiveConf: HiveConf = { - var ss = SessionState.get() - // SessionState is lazy initialization, it can be null here - if (ss == null) { - val original = Thread.currentThread().getContextClassLoader - val conf = new HiveConf(classOf[SessionState]) - conf.setClassLoader(original) - ss = new SessionState(conf) - SessionState.start(ss) - } - ss.getConf - } - - protected def getProperties(node: ASTNode): Seq[(String, String)] = node match { - case Token("TOK_TABLEPROPLIST", list) => - list.map { - case Token("TOK_TABLEPROPERTY", Token(key, Nil) :: Token(value, Nil) :: Nil) => - unquoteString(key) -> unquoteString(value) - } - } - - private def createView( - view: ASTNode, - viewNameParts: ASTNode, - query: ASTNode, - schema: Seq[CatalogColumn], - properties: Map[String, String], - allowExist: Boolean, - replace: Boolean): CreateViewAsSelect = { - val tableIdentifier = extractTableIdent(viewNameParts) - val originalText = query.source - val tableDesc = CatalogTable( - identifier = tableIdentifier, - tableType = CatalogTableType.VIRTUAL_VIEW, - schema = schema, - storage = CatalogStorageFormat( - locationUri = None, - inputFormat = None, - outputFormat = None, - serde = None, - serdeProperties = Map.empty[String, String] - ), - properties = properties, - viewOriginalText = Some(originalText), - viewText = Some(originalText)) - - // We need to keep the original SQL string so that if `spark.sql.nativeView` is - // false, we can fall back to use hive native command later. - // We can remove this when parser is configurable(can access SQLConf) in the future. - val sql = view.source - CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql) - } - - /** Creates LogicalPlan for a given SQL string. */ - override def parsePlan(sql: String): LogicalPlan = { - safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast => - if (nativeCommands.contains(ast.text)) { - HiveNativeCommand(sql) - } else if (hiveUnsupportedCommands.contains(ast.text)) { - val humanReadableText = hiveUnsupportedCommands(ast.text) - throw new AnalysisException("Unsupported operation: " + humanReadableText) - } else { - nodeToPlan(ast) match { - case NativePlaceholder => HiveNativeCommand(sql) - case plan => plan - } - } - } - } - - protected override def isNoExplainCommand(command: String): Boolean = - noExplainCommands.contains(command) - - protected override def nodeToPlan(node: ASTNode): LogicalPlan = { - node match { - case Token("TOK_DFS", Nil) => - HiveNativeCommand(node.source + " " + node.remainder) - - case Token("TOK_ADDFILE", Nil) => - AddFile(node.remainder) - - case Token("TOK_ADDJAR", Nil) => - AddJar(node.remainder) - - // Special drop table that also uncaches. - case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: ifExists) => - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - DropTable(tableName, ifExists.nonEmpty) - - // Support "ANALYZE TABLE tableName COMPUTE STATISTICS noscan" - case Token("TOK_ANALYZE", - Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: isNoscan) => - // Reference: - // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables - if (partitionSpec.nonEmpty) { - // Analyze partitions will be treated as a Hive native command. - NativePlaceholder - } else if (isNoscan.isEmpty) { - // If users do not specify "noscan", it will be treated as a Hive native command. 
- NativePlaceholder - } else { - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - AnalyzeTable(tableName) - } - - case view @ Token("TOK_ALTERVIEW", children) => - val Some(nameParts) :: maybeQuery :: _ = - getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME"), children) - - // if ALTER VIEW doesn't have query part, let hive to handle it. - maybeQuery.map { query => - createView(view, nameParts, query, Nil, Map(), allowExist = false, replace = true) - }.getOrElse(NativePlaceholder) - - case view @ Token("TOK_CREATEVIEW", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - val Seq( - Some(viewNameParts), - Some(query), - maybeComment, - replace, - allowExisting, - maybeProperties, - maybeColumns, - maybePartCols - ) = getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_TABLECOMMENT", - "TOK_ORREPLACE", - "TOK_IFNOTEXISTS", - "TOK_TABLEPROPERTIES", - "TOK_TABCOLNAME", - "TOK_VIEWPARTCOLS"), children) - - // If the view is partitioned, we let hive handle it. - if (maybePartCols.isDefined) { - NativePlaceholder - } else { - val schema = maybeColumns.map { cols => - // We can't specify column types when create view, so fill it with null first, and - // update it after the schema has been resolved later. - nodeToColumns(cols, lowerCase = true).map(_.copy(dataType = null)) - }.getOrElse(Seq.empty[CatalogColumn]) - - val properties = scala.collection.mutable.Map.empty[String, String] - - maybeProperties.foreach { - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - properties ++= getProperties(list) - } - - maybeComment.foreach { - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = unescapeSQLString(child.text) - if (comment ne null) { - properties += ("comment" -> comment) - } - } - - createView(view, viewNameParts, query, schema, properties.toMap, - allowExisting.isDefined, replace.isDefined) - } - - case Token("TOK_CREATETABLE", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val ( - Some(tableNameParts) :: - _ /* likeTable */ :: - externalTable :: - Some(query) :: - allowExisting +: - _) = - getClauses( - Seq( - "TOK_TABNAME", - "TOK_LIKETABLE", - "EXTERNAL", - "TOK_QUERY", - "TOK_IFNOTEXISTS", - "TOK_TABLECOMMENT", - "TOK_TABCOLLIST", - "TOK_TABLEPARTCOLS", // Partitioned by - "TOK_TABLEBUCKETS", // Clustered by - "TOK_TABLESKEWED", // Skewed by - "TOK_TABLEROWFORMAT", - "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", - "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat - "TOK_STORAGEHANDLER", // Storage handler - "TOK_TABLELOCATION", - "TOK_TABLEPROPERTIES"), - children) - val tableIdentifier = extractTableIdent(tableNameParts) - - // TODO add bucket support - var tableDesc: CatalogTable = CatalogTable( - identifier = tableIdentifier, - tableType = - if (externalTable.isDefined) { - CatalogTableType.EXTERNAL_TABLE - } else { - CatalogTableType.MANAGED_TABLE - }, - storage = CatalogStorageFormat( - locationUri = None, - inputFormat = None, - outputFormat = None, - serde = None, - serdeProperties = Map.empty[String, String] - ), - schema = Seq.empty[CatalogColumn]) - - // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) 
- val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbreviation - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } - - tableDesc = tableDesc.withNewStorage( - inputFormat = hiveSerDe.inputFormat.orElse(tableDesc.storage.inputFormat), - outputFormat = hiveSerDe.outputFormat.orElse(tableDesc.storage.outputFormat), - serde = hiveSerDe.serde.orElse(tableDesc.storage.serde)) - - children.collect { - case list @ Token("TOK_TABCOLLIST", _) => - val cols = nodeToColumns(list, lowerCase = true) - if (cols != null) { - tableDesc = tableDesc.copy(schema = cols) - } - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = unescapeSQLString(child.text) - // TODO support the sql text - tableDesc = tableDesc.copy(viewText = Option(comment)) - case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = nodeToColumns(list.head, lowerCase = false) - if (cols != null) { - tableDesc = tableDesc.copy(partitionColumns = cols) - } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => - val serdeParams = new java.util.HashMap[String, String]() - child match { - case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = unescapeSQLString (rowChild1.text) - serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) - serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) - if (rowChild2.length > 1) { - val fieldEscape = unescapeSQLString (rowChild2.head.text) - serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) - } - case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = unescapeSQLString(rowChild.text) - serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) - case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = unescapeSQLString(rowChild.text) - serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) - case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = unescapeSQLString(rowChild.text) - if (!(lineDelim == "\n") && !(lineDelim == "10")) { - throw new AnalysisException( - s"LINES TERMINATED BY only supports newline '\\n' right now: $rowChild") - } - serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) - case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = unescapeSQLString(rowChild.text) - // TODO support the nullFormat - case _ => assert(false) - } - tableDesc = tableDesc.withNewStorage( - serdeProperties = tableDesc.storage.serdeProperties ++ serdeParams.asScala) - case Token("TOK_TABLELOCATION", child :: Nil) => - val location = EximUtil.relativeToAbsolutePath(hiveConf, unescapeSQLString(child.text)) - tableDesc = tableDesc.withNewStorage(locationUri = Option(location)) - case Token("TOK_TABLESERIALIZER", child :: Nil) => - tableDesc = tableDesc.withNewStorage( - serde = Option(unescapeSQLString(child.children.head.text))) - if (child.numChildren == 2) { - // This is based on the readProps(..) 
method in - // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: - val serdeParams = child.children(1).children.head.children.map { - case Token(_, Token(prop, Nil) :: valueNode) => - val value = valueNode.headOption - .map(_.text) - .map(unescapeSQLString) - .orNull - (unescapeSQLString(prop), value) - }.toMap - tableDesc = tableDesc.withNewStorage( - serdeProperties = tableDesc.storage.serdeProperties ++ serdeParams) - } - case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - child.text.toLowerCase(Locale.ENGLISH) match { - case "orc" => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.storage.serde.isEmpty) { - tableDesc = tableDesc.withNewStorage( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - case "parquet" => - tableDesc = tableDesc.withNewStorage( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.storage.serde.isEmpty) { - tableDesc = tableDesc.withNewStorage( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - - case "rcfile" => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.storage.serde.isEmpty) { - tableDesc = tableDesc.withNewStorage( - serde = - Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } - - case "textfile" => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - - case "sequencefile" => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - case "avro" => - tableDesc = tableDesc.withNewStorage( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) - if (tableDesc.storage.serde.isEmpty) { - tableDesc = tableDesc.withNewStorage( - serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) - } - - case _ => - throw new AnalysisException( - s"Unrecognized file format in STORED AS clause: ${child.text}") - } - - case Token("TOK_TABLESERIALIZER", - Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => - tableDesc = tableDesc.withNewStorage(serde = Option(unquoteString(serdeName))) - - otherProps match { - case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => - tableDesc = tableDesc.withNewStorage( - serdeProperties = tableDesc.storage.serdeProperties ++ getProperties(list)) - case _ => - } - - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", _) => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option(unescapeSQLString(list.children.head.text)), - outputFormat = Option(unescapeSQLString(list.children(1).text))) - case Token("TOK_STORAGEHANDLER", _) => - throw new AnalysisException( - "CREATE TABLE AS SELECT cannot be used for a non-native table") - 
case _ => // Unsupported features - } - - CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined) - - // If its not a "CTAS" like above then take it as a native command - case Token("TOK_CREATETABLE", _) => - NativePlaceholder - - // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" - case Token("TOK_TRUNCATETABLE", Token("TOK_TABLE_PARTITION", table) :: Nil) => - NativePlaceholder - - case _ => - super.nodeToPlan(node) - } - } - - protected override def nodeToDescribeFallback(node: ASTNode): LogicalPlan = NativePlaceholder - - protected override def nodeToTransformation( - node: ASTNode, - child: LogicalPlan): Option[logical.ScriptTransformation] = node match { - case Token("TOK_SELEXPR", - Token("TOK_TRANSFORM", - Token("TOK_EXPLIST", inputExprs) :: - Token("TOK_SERDE", inputSerdeClause) :: - Token("TOK_RECORDWRITER", writerClause) :: - // TODO: Need to support other types of (in/out)put - Token(script, Nil) :: - Token("TOK_SERDE", outputSerdeClause) :: - Token("TOK_RECORDREADER", readerClause) :: - outputClause) :: Nil) => - - val (output, schemaLess) = outputClause match { - case Token("TOK_ALIASLIST", aliases) :: Nil => - (aliases.map { case Token(name, Nil) => - AttributeReference(cleanIdentifier(name), StringType)() }, false) - case Token("TOK_TABCOLLIST", attributes) :: Nil => - (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => - AttributeReference(cleanIdentifier(name), nodeToDataType(dataType))() }, false) - case Nil => - (List(AttributeReference("key", StringType)(), - AttributeReference("value", StringType)()), true) - case _ => - noParseRule("Transform", node) - } - - type SerDeInfo = ( - Seq[(String, String)], // Input row format information - Option[String], // Optional input SerDe class - Seq[(String, String)], // Input SerDe properties - Boolean // Whether to use default record reader/writer - ) - - def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { - case Token("TOK_SERDEPROPS", propsClause) :: Nil => - val rowFormat = propsClause.map { - case Token(name, Token(value, Nil) :: Nil) => (name, value) - } - (rowFormat, None, Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(unescapeSQLString(serdeClass)), Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: - Token("TOK_TABLEPROPERTIES", - Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => - val serdeProps = propsClause.map { - case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (unescapeSQLString(name), unescapeSQLString(value)) - } - - // SPARK-10310: Special cases LazySimpleSerDe - // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = unescapeSQLString(serdeClass) - val useDefaultRecordReaderWriter = - unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName - (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) - - case Nil => - // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here - val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") - (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) - } - - val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = - matchSerDe(inputSerdeClause) - - val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = - matchSerDe(outputSerdeClause) - - val unescapedScript = unescapeSQLString(script) - - // TODO Adds support for user-defined record reader/writer 
classes - val recordReaderClass = if (useDefaultRecordReader) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) - } else { - None - } - - val recordWriterClass = if (useDefaultRecordWriter) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) - } else { - None - } - - val schema = HiveScriptIOSchema( - inRowFormat, outRowFormat, - inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, - recordReaderClass, recordWriterClass, - schemaLess) - - Some( - logical.ScriptTransformation( - inputExprs.map(nodeToExpr), - unescapedScript, - output, - child, schema)) - case _ => None - } - - protected override def nodeToGenerator(node: ASTNode): Generator = node match { - case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $functionName")) - val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUDTF( - functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) - case other => super.nodeToGenerator(node) - } - - // This is based the getColumns methods in - // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java - protected def nodeToColumns(node: ASTNode, lowerCase: Boolean): Seq[CatalogColumn] = { - node.children.map(_.children).collect { - case Token(rawColName, Nil) :: colTypeNode :: comment => - val colName = if (!lowerCase) rawColName else rawColName.toLowerCase - CatalogColumn( - name = cleanIdentifier(colName), - dataType = nodeToTypeString(colTypeNode), - nullable = true, - comment.headOption.map(n => unescapeSQLString(n.text))) - } - } - - // This is based on the following methods in - // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: - // getTypeStringFromAST - // getStructTypeStringFromAST - // getUnionTypeStringFromAST - protected def nodeToTypeString(node: ASTNode): String = node.tokenType match { - case SparkSqlParser.TOK_LIST => - val listType :: Nil = node.children - val listTypeString = nodeToTypeString(listType) - s"${serdeConstants.LIST_TYPE_NAME}<$listTypeString>" - - case SparkSqlParser.TOK_MAP => - val keyType :: valueType :: Nil = node.children - val keyTypeString = nodeToTypeString(keyType) - val valueTypeString = nodeToTypeString(valueType) - s"${serdeConstants.MAP_TYPE_NAME}<$keyTypeString,$valueTypeString>" - - case SparkSqlParser.TOK_STRUCT => - val typeNode = node.children.head - require(typeNode.children.nonEmpty, "Struct must have one or more columns.") - val structColStrings = typeNode.children.map { columnNode => - val Token(colName, Nil) :: colTypeNode :: Nil = columnNode.children - cleanIdentifier(colName) + ":" + nodeToTypeString(colTypeNode) - } - s"${serdeConstants.STRUCT_TYPE_NAME}<${structColStrings.mkString(",")}>" - - case SparkSqlParser.TOK_UNIONTYPE => - val typeNode = node.children.head - val unionTypesString = typeNode.children.map(nodeToTypeString).mkString(",") - s"${serdeConstants.UNION_TYPE_NAME}<$unionTypesString>" - - case SparkSqlParser.TOK_CHAR => - val Token(size, Nil) :: Nil = node.children - s"${serdeConstants.CHAR_TYPE_NAME}($size)" - - case SparkSqlParser.TOK_VARCHAR => - val Token(size, Nil) :: Nil = node.children - s"${serdeConstants.VARCHAR_TYPE_NAME}($size)" - - case SparkSqlParser.TOK_DECIMAL => - val precisionAndScale = node.children match { - case Token(precision, Nil) :: Token(scale, Nil) :: Nil => - precision + "," + scale - case Token(precision, Nil) :: Nil => - 
precision + "," + HiveDecimal.USER_DEFAULT_SCALE - case Nil => - HiveDecimal.USER_DEFAULT_PRECISION + "," + HiveDecimal.USER_DEFAULT_SCALE - case _ => - noParseRule("Decimal", node) - } - s"${serdeConstants.DECIMAL_TYPE_NAME}($precisionAndScale)" - - // Simple data types. - case SparkSqlParser.TOK_BOOLEAN => serdeConstants.BOOLEAN_TYPE_NAME - case SparkSqlParser.TOK_TINYINT => serdeConstants.TINYINT_TYPE_NAME - case SparkSqlParser.TOK_SMALLINT => serdeConstants.SMALLINT_TYPE_NAME - case SparkSqlParser.TOK_INT => serdeConstants.INT_TYPE_NAME - case SparkSqlParser.TOK_BIGINT => serdeConstants.BIGINT_TYPE_NAME - case SparkSqlParser.TOK_FLOAT => serdeConstants.FLOAT_TYPE_NAME - case SparkSqlParser.TOK_DOUBLE => serdeConstants.DOUBLE_TYPE_NAME - case SparkSqlParser.TOK_STRING => serdeConstants.STRING_TYPE_NAME - case SparkSqlParser.TOK_BINARY => serdeConstants.BINARY_TYPE_NAME - case SparkSqlParser.TOK_DATE => serdeConstants.DATE_TYPE_NAME - case SparkSqlParser.TOK_TIMESTAMP => serdeConstants.TIMESTAMP_TYPE_NAME - case SparkSqlParser.TOK_INTERVAL_YEAR_MONTH => serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME - case SparkSqlParser.TOK_INTERVAL_DAY_TIME => serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME - case SparkSqlParser.TOK_DATETIME => serdeConstants.DATETIME_TYPE_NAME - case _ => null - } - -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index d6a08fcc57..12e4f49756 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -29,8 +29,8 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.ng._ -import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlAstBuilder import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView} @@ -161,18 +161,10 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } // Create the schema. - val schema = Option(ctx.colTypeList).toSeq.flatMap(_.colType.asScala).map { col => - CatalogColumn( - col.identifier.getText, - col.dataType.getText.toLowerCase, // TODO validate this? - nullable = true, - Option(col.STRING).map(string)) - } + val schema = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns(_, _.toLowerCase)) // Get the column by which the table is partitioned. - val partitionCols = Option(ctx.identifierList).toSeq.flatMap(visitIdentifierList).map { - CatalogColumn(_, null, nullable = true, None) - } + val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns(_)) // Create the storage. 
def format(fmt: ParserRuleContext): CatalogStorageFormat = { @@ -439,4 +431,19 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } EmptyStorageFormat.copy(serdeProperties = entries.toMap) } + + /** + * Create a sequence of [[CatalogColumn]]s from a column list + */ + private def visitCatalogColumns( + ctx: ColTypeListContext, + formatter: String => String = identity): Seq[CatalogColumn] = withOrigin(ctx) { + ctx.colType.asScala.map { col => + CatalogColumn( + formatter(col.identifier.getText), + col.dataType.getText.toLowerCase, // TODO validate this? + nullable = true, + Option(col.STRING).map(string)) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 4b6da7cd33..d9664680f4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,8 +22,8 @@ import scala.util.Try import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.execution.HiveSqlParser import org.apache.spark.sql.hive.test.TestHiveSingleton class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterEach { @@ -131,7 +131,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { - def ast = ParseDriver.parsePlan(query, hiveContext.conf) + def ast = HiveSqlParser.parsePlan(query) def parseTree = Try(quietly(ast.treeString)).getOrElse("") test(name) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 0aaf57649c..75108c6d47 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -24,11 +24,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.parser.SimpleParserConf import org.apache.spark.sql.catalyst.plans.logical.Generate +import org.apache.spark.sql.hive.execution.HiveSqlParser class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { - val parser = new HiveQl(SimpleParserConf()) + val parser = HiveSqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parser.parsePlan(sql).collect { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ae026ed496..05318f51af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser.SimpleParserConf import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ import 
org.apache.spark.sql.hive.test.TestHiveSingleton @@ -30,11 +29,9 @@ import org.apache.spark.sql.internal.SQLConf class StatisticsSuite extends QueryTest with TestHiveSingleton { import hiveContext.sql - val parser = new HiveQl(SimpleParserConf()) - test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = parser.parsePlan(analyzeCommand) + val parsed = HiveSqlParser.parsePlan(analyzeCommand) val operators = parsed.collect { case a: AnalyzeTable => a case o => o -- cgit v1.2.3 From 208fff3ac87f200fd4e6f0407d70bf81cf8c556f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 31 Mar 2016 09:39:15 -0700 Subject: [SPARK-14164][MLLIB] Improve input layer validation of MultilayerPerceptronClassifier ## What changes were proposed in this pull request? This issue improves an input layer validation and adds related testcases to MultilayerPerceptronClassifier. ```scala - // TODO: how to check ALSO that all elements are greater than 0? - ParamValidators.arrayLengthGt(1) + (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1 ``` ## How was this patch tested? Pass the Jenkins tests including the new testcases. Author: Dongjoon Hyun Closes #11964 from dongjoon-hyun/SPARK-14164. --- .../classification/MultilayerPerceptronClassifier.scala | 3 +-- .../MultilayerPerceptronClassifierSuite.scala | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index f6de5f2df4..7ce3ec68da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -43,8 +43,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams "Sizes of layers from input layer to output layer" + " E.g., Array(780, 100, 10) means 780 inputs, " + "one hidden layer with 100 neurons and output layer of 10 neurons.", - // TODO: how to check ALSO that all elements are greater than 0? 
- ParamValidators.arrayLengthGt(1) + (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1 ) /** @group getParam */ diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 5df8e6a847..53c7a559e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -43,6 +43,23 @@ class MultilayerPerceptronClassifierSuite ).toDF("features", "label") } + test("Input Validation") { + val mlpc = new MultilayerPerceptronClassifier() + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int]()) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](0, 1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1, 0)) + } + mlpc.setLayers(Array[Int](1, 1)) + } + test("XOR function learning as binary classification problem with two outputs.") { val layers = Array[Int](2, 5, 2) val trainer = new MultilayerPerceptronClassifier() -- cgit v1.2.3 From 3b3cc76004438a942ecea752db39f3a904a52462 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 31 Mar 2016 10:27:33 -0700 Subject: [SPARK-14062][YARN] Fix log4j and upload metrics.properties automatically with distributed cache ## What changes were proposed in this pull request? 1. Currently log4j which uses distributed cache only adds to AM's classpath, not executor's, this is introduced in #9118, which breaks the original meaning of that PR, so here add log4j file to the classpath of both AM and executors. 2. Automatically upload metrics.properties to distributed cache, so that it could be used by remote driver and executors implicitly. ## How was this patch tested? Unit test and integration test is done. Author: jerryshao Closes #11885 from jerryshao/SPARK-14062. --- .../org/apache/spark/deploy/yarn/Client.scala | 71 ++++++++-------------- .../spark/deploy/yarn/ExecutorRunnable.scala | 3 +- .../org/apache/spark/deploy/yarn/ClientSuite.scala | 7 +-- 3 files changed, 31 insertions(+), 50 deletions(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 7b29c1ae4d..f0f13a16e0 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -351,14 +351,6 @@ private[spark] class Client( val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - val oldLog4jConf = Option(System.getenv("SPARK_LOG4J_CONF")) - if (oldLog4jConf.isDefined) { - logWarning( - "SPARK_LOG4J_CONF detected in the system environment. This variable has been " + - "deprecated. Please refer to the \"Launching Spark on YARN\" documentation " + - "for alternatives.") - } - def addDistributedUri(uri: URI): Boolean = { val uriStr = uri.toString() if (distributedUris.contains(uriStr)) { @@ -479,25 +471,16 @@ private[spark] class Client( } /** - * Copy a few resources to the distributed cache if their scheme is not "local". + * Copy user jar to the distributed cache if their scheme is not "local". * Otherwise, set the corresponding key in our SparkConf to handle it downstream. 
- * Each resource is represented by a 3-tuple of: - * (1) destination resource name, - * (2) local path to the resource, - * (3) Spark property key to set if the scheme is not local */ - List( - (APP_JAR_NAME, args.userJar, APP_JAR), - ("log4j.properties", oldLog4jConf.orNull, null) - ).foreach { case (destName, path, confKey) => - if (path != null && !path.trim().isEmpty()) { - val (isLocal, localizedPath) = distribute(path, destName = Some(destName)) - if (isLocal && confKey != null) { - require(localizedPath != null, s"Path $path already distributed.") - // If the resource is intended for local use only, handle this downstream - // by setting the appropriate property - sparkConf.set(confKey, localizedPath) - } + Option(args.userJar).filter(_.trim.nonEmpty).foreach { jar => + val (isLocal, localizedPath) = distribute(jar, destName = Some(APP_JAR_NAME)) + if (isLocal) { + require(localizedPath != null, s"Path $jar already distributed") + // If the resource is intended for local use only, handle this downstream + // by setting the appropriate property + sparkConf.set(APP_JAR, localizedPath) } } @@ -541,11 +524,10 @@ private[spark] class Client( distribute(f, targetDir = targetDir) } - // Distribute an archive with Hadoop and Spark configuration for the AM. + // Distribute an archive with Hadoop and Spark configuration for the AM and executors. val (_, confLocalizedPath) = distribute(createConfArchive().toURI().getPath(), resType = LocalResourceType.ARCHIVE, - destName = Some(LOCALIZED_CONF_DIR), - appMasterOnly = true) + destName = Some(LOCALIZED_CONF_DIR)) require(confLocalizedPath != null) localResources @@ -554,10 +536,10 @@ private[spark] class Client( /** * Create an archive with the config files for distribution. * - * These are only used by the AM, since executors will use the configuration object broadcast by - * the driver. The files are zipped and added to the job as an archive, so that YARN will explode - * it when distributing to the AM. This directory is then added to the classpath of the AM - * process, just to make sure that everybody is using the same default config. + * These will be used by AM and executors. The files are zipped and added to the job as an + * archive, so that YARN will explode it when distributing to AM and executors. This directory + * is then added to the classpath of AM and executor process, just to make sure that everybody + * is using the same default config. * * This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR * shows up in the classpath before YARN_CONF_DIR. @@ -576,11 +558,14 @@ private[spark] class Client( // required when user changes log4j.properties directly to set the log configurations. If // configuration file is provided through --files then executors will be taking configurations // from --files instead of $SPARK_CONF_DIR/log4j.properties. - val log4jFileName = "log4j.properties" - Option(Utils.getContextOrSparkClassLoader.getResource(log4jFileName)).foreach { url => - if (url.getProtocol == "file") { - hadoopConfFiles(log4jFileName) = new File(url.getPath) - } + + // Also uploading metrics.properties to distributed cache if exists in classpath. + // If user specify this file using --files then executors will use the one + // from --files instead. 
+ for { prop <- Seq("log4j.properties", "metrics.properties") + url <- Option(Utils.getContextOrSparkClassLoader.getResource(prop)) + if url.getProtocol == "file" } { + hadoopConfFiles(prop) = new File(url.getPath) } Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey => @@ -659,7 +644,7 @@ private[spark] class Client( pySparkArchives: Seq[String]): HashMap[String, String] = { logInfo("Setting up the launch environment for our AM container") val env = new HashMap[String, String]() - populateClasspath(args, yarnConf, sparkConf, env, true, sparkConf.get(DRIVER_CLASS_PATH)) + populateClasspath(args, yarnConf, sparkConf, env, sparkConf.get(DRIVER_CLASS_PATH)) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() @@ -1236,18 +1221,16 @@ object Client extends Logging { conf: Configuration, sparkConf: SparkConf, env: HashMap[String, String], - isAM: Boolean, extraClassPath: Option[String] = None): Unit = { extraClassPath.foreach { cp => addClasspathEntry(getClusterPath(sparkConf, cp), env) } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) - if (isAM) { - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + - LOCALIZED_CONF_DIR, env) - } + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + + LOCALIZED_CONF_DIR, env) if (sparkConf.get(USER_CLASS_PATH_FIRST)) { // in order to properly add the app jar when user classpath is first diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index f956a4d1d5..7b55d781f8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -289,8 +289,7 @@ private[yarn] class ExecutorRunnable( private def prepareEnvironment(container: Container): HashMap[String, String] = { val env = new HashMap[String, String]() - Client.populateClasspath(null, yarnConf, sparkConf, env, false, - sparkConf.get(EXECUTOR_CLASS_PATH)) + Client.populateClasspath(null, yarnConf, sparkConf, env, sparkConf.get(EXECUTOR_CLASS_PATH)) sparkConf.getExecutorEnv.foreach { case (key, value) => // This assumes each executor environment variable set here is a path diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index e3613a93ed..64723c361c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -121,7 +121,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll val env = new MutableHashMap[String, String]() val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) - populateClasspath(args, conf, sparkConf, env, true) + populateClasspath(args, conf, sparkConf, env) val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => @@ -178,8 +178,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll "/remotePath/1:/remotePath/2") val env = new MutableHashMap[String, String]() - populateClasspath(null, conf, sparkConf, env, false, - extraClassPath = Some("/localPath/my1.jar")) + populateClasspath(null, conf, sparkConf, env, extraClassPath = Some("/localPath/my1.jar")) val cp = 
classpath(env) cp should contain ("/remotePath/spark.jar") cp should contain ("/remotePath/my1.jar") @@ -356,7 +355,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll private def classpath(client: Client): Array[String] = { val env = new MutableHashMap[String, String]() - populateClasspath(null, client.hadoopConf, client.sparkConf, env, false) + populateClasspath(null, client.hadoopConf, client.sparkConf, env) classpath(env) } -- cgit v1.2.3 From a0a1991580ed24230f88cae9f5a4dfbe58f03b28 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Thu, 31 Mar 2016 11:12:40 -0700 Subject: [SPARK-13782][ML] Model export/import for spark.ml: BisectingKMeans ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-13782 Model export/import for BisectingKMeans in spark.ml and mllib ## How was this patch tested? unit tests Author: Yuhao Yang Closes #11933 from hhbyyh/bisectingsave. --- .../spark/ml/clustering/BisectingKMeans.scala | 59 +++++++++++-- .../spark/mllib/clustering/BisectingKMeans.scala | 2 +- .../mllib/clustering/BisectingKMeansModel.scala | 98 +++++++++++++++++++++- .../spark/ml/clustering/BisectingKMeansSuite.scala | 22 ++++- .../mllib/clustering/BisectingKMeansSuite.scala | 18 ++++ 5 files changed, 190 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index f014a1d572..55f751c57f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -17,11 +17,13 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.ml.util.{Identifiable, SchemaUtils} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering. 
{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -49,7 +51,7 @@ private[clustering] trait BisectingKMeansParams extends Params /** @group expertParam */ @Since("2.0.0") - final val minDivisibleClusterSize = new Param[Double]( + final val minDivisibleClusterSize = new DoubleParam( this, "minDivisibleClusterSize", "the minimum number of points (if >= 1.0) or the minimum proportion", @@ -81,7 +83,7 @@ private[clustering] trait BisectingKMeansParams extends Params class BisectingKMeansModel private[ml] ( @Since("2.0.0") override val uid: String, private val parentModel: MLlibBisectingKMeansModel - ) extends Model[BisectingKMeansModel] with BisectingKMeansParams { + ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable { @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { @@ -115,6 +117,44 @@ class BisectingKMeansModel private[ml] ( val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) } + + @Since("2.0.0") + override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this) +} + +object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { + @Since("2.0.0") + override def read: MLReader[BisectingKMeansModel] = new BisectingKMeansModelReader + + @Since("2.0.0") + override def load(path: String): BisectingKMeansModel = super.load(path) + + /** [[MLWriter]] instance for [[BisectingKMeansModel]] */ + private[BisectingKMeansModel] + class BisectingKMeansModelWriter(instance: BisectingKMeansModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + val dataPath = new Path(path, "data").toString + instance.parentModel.save(sc, dataPath) + } + } + + private class BisectingKMeansModelReader extends MLReader[BisectingKMeansModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[BisectingKMeansModel].getName + + override def load(path: String): BisectingKMeansModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val mllibModel = MLlibBisectingKMeansModel.load(sc, dataPath) + val model = new BisectingKMeansModel(metadata.uid, mllibModel) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } } /** @@ -137,7 +177,7 @@ class BisectingKMeansModel private[ml] ( @Experimental class BisectingKMeans @Since("2.0.0") ( @Since("2.0.0") override val uid: String) - extends Estimator[BisectingKMeansModel] with BisectingKMeansParams { + extends Estimator[BisectingKMeansModel] with BisectingKMeansParams with DefaultParamsWritable { setDefault( k -> 4, @@ -148,7 +188,7 @@ class BisectingKMeans @Since("2.0.0") ( override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra) @Since("2.0.0") - def this() = this(Identifiable.randomUID("bisecting k-means")) + def this() = this(Identifiable.randomUID("bisecting-kmeans")) /** @group setParam */ @Since("2.0.0") @@ -194,3 +234,10 @@ class BisectingKMeans @Since("2.0.0") ( } } + +@Since("2.0.0") +object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] { + + @Since("2.0.0") + override def load(path: String): BisectingKMeans = super.load(path) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala index 64b838a1db..e4bd0dc25e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala @@ -411,7 +411,7 @@ private object BisectingKMeans extends Serializable { private[clustering] class ClusteringTreeNode private[clustering] ( val index: Int, val size: Long, - private val centerWithNorm: VectorWithNorm, + private[clustering] val centerWithNorm: VectorWithNorm, val cost: Double, val height: Double, val children: Array[ClusteringTreeNode]) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala index 01a0d31f14..c3b5b8b790 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala @@ -17,11 +17,19 @@ package org.apache.spark.mllib.clustering +import org.json4s._ +import org.json4s.DefaultFormats +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonDSL._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext} /** * Clustering model produced by [[BisectingKMeans]]. @@ -34,7 +42,7 @@ import org.apache.spark.rdd.RDD @Experimental class BisectingKMeansModel private[clustering] ( private[clustering] val root: ClusteringTreeNode - ) extends Serializable with Logging { + ) extends Serializable with Saveable with Logging { /** * Leaf cluster centers. @@ -92,4 +100,92 @@ class BisectingKMeansModel private[clustering] ( */ @Since("1.6.0") def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd) + + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + BisectingKMeansModel.SaveLoadV1_0.save(sc, this, path) + } + + override protected def formatVersion: String = "1.0" +} + +@Since("2.0.0") +object BisectingKMeansModel extends Loader[BisectingKMeansModel] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): BisectingKMeansModel = { + val (loadedClassName, formatVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val rootId = (metadata \ "rootId").extract[Int] + val classNameV1_0 = SaveLoadV1_0.thisClassName + (loadedClassName, formatVersion) match { + case (classNameV1_0, "1.0") => + val model = SaveLoadV1_0.load(sc, path, rootId) + model + case _ => throw new Exception( + s"BisectingKMeansModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $formatVersion). 
Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + } + + private case class Data(index: Int, size: Long, center: Vector, norm: Double, cost: Double, + height: Double, children: Seq[Int]) + + private object Data { + def apply(r: Row): Data = Data(r.getInt(0), r.getLong(1), r.getAs[Vector](2), r.getDouble(3), + r.getDouble(4), r.getDouble(5), r.getSeq[Int](6)) + } + + private[clustering] object SaveLoadV1_0 { + private val thisFormatVersion = "1.0" + + private[clustering] + val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" + + def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) + ~ ("rootId" -> model.root.index))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val data = getNodes(model.root).map(node => Data(node.index, node.size, + node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, + node.children.map(_.index))) + val dataRDD = sc.parallelize(data).toDF() + dataRDD.write.parquet(Loader.dataPath(path)) + } + + private def getNodes(node: ClusteringTreeNode): Array[ClusteringTreeNode] = { + if (node.children.isEmpty) { + Array(node) + } else { + node.children.flatMap(getNodes(_)) ++ Array(node) + } + } + + def load(sc: SparkContext, path: String, rootId: Int): BisectingKMeansModel = { + val sqlContext = SQLContext.getOrCreate(sc) + val rows = sqlContext.read.parquet(Loader.dataPath(path)) + Loader.checkSchema[Data](rows.schema) + val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") + val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap + val rootNode = buildTree(rootId, nodes) + new BisectingKMeansModel(rootNode) + } + + private def buildTree(rootId: Int, nodes: Map[Int, Data]): ClusteringTreeNode = { + val root = nodes.get(rootId).get + if (root.children.isEmpty) { + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, new Array[ClusteringTreeNode](0)) + } else { + val children = root.children.map(c => buildTree(c, nodes)) + new ClusteringTreeNode(root.index, root.size, new VectorWithNorm(root.center, root.norm), + root.cost, root.height, children.toArray) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index b719a8c7e7..18f2c994b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame -class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class BisectingKMeansSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) } + + test("read/write") { + def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit 
= { + assert(model.clusterCenters === model2.clusterCenters) + } + val bisectingKMeans = new BisectingKMeans() + testEstimatorAndModelReadWrite( + bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData) + } +} + +object BisectingKMeansSuite { + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "seed" -> -1L, + "minDivisibleClusterSize" -> 2.0 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala index 41b9d5c0d9..35f7932ae8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -179,4 +180,21 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("BisectingKMeans model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val points = (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val model = new BisectingKMeans().run(data) + try { + model.save(sc, path) + val sameModel = BisectingKMeansModel.load(sc, path) + assert(model.k === sameModel.k) + model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2) + } finally { + Utils.deleteRecursively(tempDir) + } + } } -- cgit v1.2.3 From 8b207f3b6a0eb617d38091f3b9001830ac3651fe Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 31 Mar 2016 11:17:32 -0700 Subject: [SPARK-11892][ML] Model export/import for spark.ml: OneVsRest # What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-11892 Add save/load for spark ml.OneVsRest and its model. Also add OneVsRest and OneVsRestModel in MetaAlgorithmReadWrite. # How was this patch tested? Test with Scala unit test. Author: Xusen Yin Closes #9934 from yinxusen/SPARK-11892. 
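As a rough usage sketch (it mirrors the new read/write tests added to `OneVsRestSuite` below; the `dataset` DataFrame and the output paths are placeholders, not part of this patch), both the estimator and the fitted model can now be saved and loaded like any other writable ML stage:

```scala
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest, OneVsRestModel}

// Wrap a writable base classifier; a non-writable classifier still makes save() fail.
val ova = new OneVsRest()
  .setClassifier(new LogisticRegression().setMaxIter(10).setRegParam(0.01))
val ovaModel = ova.fit(dataset)  // `dataset`: any DataFrame with label/features columns

// Persist the estimator and the model, then read the model back.
ova.write.overwrite().save("/tmp/ova")
ovaModel.write.overwrite().save("/tmp/ovaModel")
val restored = OneVsRestModel.load("/tmp/ovaModel")
```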
--- .../apache/spark/ml/classification/OneVsRest.scala | 165 +++++++++++++++++++-- .../scala/org/apache/spark/ml/util/ReadWrite.scala | 8 +- .../spark/ml/classification/OneVsRestSuite.scala | 68 ++++++++- 3 files changed, 223 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index c41a611f1c..98b99a3485 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -21,22 +21,24 @@ import java.util.UUID import scala.language.existentials +import org.apache.hadoop.fs.Path +import org.json4s.{DefaultFormats, JObject, _} +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} +import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel -/** - * Params for [[OneVsRest]]. - */ -private[ml] trait OneVsRestParams extends PredictorParams { - +private[ml] trait ClassifierTypeTrait { // scalastyle:off structural.type type ClassifierType = Classifier[F, E, M] forSome { type F @@ -44,6 +46,12 @@ private[ml] trait OneVsRestParams extends PredictorParams { type E <: Classifier[F, E, M] } // scalastyle:on structural.type +} + +/** + * Params for [[OneVsRest]]. + */ +private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait { /** * param for the base binary classifier that we reduce multiclass classification into. @@ -57,6 +65,55 @@ private[ml] trait OneVsRestParams extends PredictorParams { def getClassifier: ClassifierType = $(classifier) } +private[ml] object OneVsRestParams extends ClassifierTypeTrait { + + def validateParams(instance: OneVsRestParams): Unit = { + def checkElement(elem: Params, name: String): Unit = elem match { + case stage: MLWritable => // good + case other => + throw new UnsupportedOperationException("OneVsRest write will fail " + + s" because it contains $name which does not implement MLWritable." 
+ + s" Non-Writable $name: ${other.uid} of type ${other.getClass}") + } + + instance match { + case ovrModel: OneVsRestModel => ovrModel.models.foreach(checkElement(_, "model")) + case _ => // no need to check OneVsRest here + } + + checkElement(instance.getClassifier, "classifier") + } + + def saveImpl( + path: String, + instance: OneVsRestParams, + sc: SparkContext, + extraMetadata: Option[JObject] = None): Unit = { + + val params = instance.extractParamMap().toSeq + val jsonParams = render(params + .filter { case ParamPair(p, v) => p.name != "classifier" } + .map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) } + .toList) + + DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) + + val classifierPath = new Path(path, "classifier").toString + instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath) + } + + def loadImpl( + path: String, + sc: SparkContext, + expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = { + + val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName) + val classifierPath = new Path(path, "classifier").toString + val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc) + (metadata, estimator) + } +} + /** * :: Experimental :: * Model produced by [[OneVsRest]]. @@ -73,10 +130,10 @@ private[ml] trait OneVsRestParams extends PredictorParams { @Since("1.4.0") @Experimental final class OneVsRestModel private[ml] ( - @Since("1.4.0") override val uid: String, - @Since("1.4.0") labelMetadata: Metadata, + @Since("1.4.0") override val uid: String, + private[ml] val labelMetadata: Metadata, @Since("1.4.0") val models: Array[_ <: ClassificationModel[_, _]]) - extends Model[OneVsRestModel] with OneVsRestParams { + extends Model[OneVsRestModel] with OneVsRestParams with MLWritable { @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { @@ -143,6 +200,56 @@ final class OneVsRestModel private[ml] ( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) copyValues(copied, extra).setParent(parent) } + + @Since("2.0.0") + override def write: MLWriter = new OneVsRestModel.OneVsRestModelWriter(this) +} + +@Since("2.0.0") +object OneVsRestModel extends MLReadable[OneVsRestModel] { + + @Since("2.0.0") + override def read: MLReader[OneVsRestModel] = new OneVsRestModelReader + + @Since("2.0.0") + override def load(path: String): OneVsRestModel = super.load(path) + + /** [[MLWriter]] instance for [[OneVsRestModel]] */ + private[OneVsRestModel] class OneVsRestModelWriter(instance: OneVsRestModel) extends MLWriter { + + OneVsRestParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~ + ("numClasses" -> instance.models.length) + OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson)) + instance.models.zipWithIndex.foreach { case (model: MLWritable, idx) => + val modelPath = new Path(path, s"model_$idx").toString + model.save(modelPath) + } + } + } + + private class OneVsRestModelReader extends MLReader[OneVsRestModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[OneVsRestModel].getName + + override def load(path: String): OneVsRestModel = { + implicit val format = DefaultFormats + val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) + val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String]) + 
val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val models = Range(0, numClasses).toArray.map { idx => + val modelPath = new Path(path, s"model_$idx").toString + DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc) + } + val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models) + DefaultParamsReader.getAndSetParams(ovrModel, metadata) + ovrModel.set("classifier", classifier) + ovrModel + } + } } /** @@ -158,7 +265,7 @@ final class OneVsRestModel private[ml] ( @Experimental final class OneVsRest @Since("1.4.0") ( @Since("1.4.0") override val uid: String) - extends Estimator[OneVsRestModel] with OneVsRestParams { + extends Estimator[OneVsRestModel] with OneVsRestParams with MLWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("oneVsRest")) @@ -243,4 +350,40 @@ final class OneVsRest @Since("1.4.0") ( } copied } + + @Since("2.0.0") + override def write: MLWriter = new OneVsRest.OneVsRestWriter(this) +} + +@Since("2.0.0") +object OneVsRest extends MLReadable[OneVsRest] { + + @Since("2.0.0") + override def read: MLReader[OneVsRest] = new OneVsRestReader + + @Since("2.0.0") + override def load(path: String): OneVsRest = super.load(path) + + /** [[MLWriter]] instance for [[OneVsRest]] */ + private[OneVsRest] class OneVsRestWriter(instance: OneVsRest) extends MLWriter { + + OneVsRestParams.validateParams(instance) + + override protected def saveImpl(path: String): Unit = { + OneVsRestParams.saveImpl(path, instance, sc) + } + } + + private class OneVsRestReader extends MLReader[OneVsRest] { + + /** Checked against metadata when loading model */ + private val className = classOf[OneVsRest].getName + + override def load(path: String): OneVsRest = { + val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className) + val ovr = new OneVsRest(metadata.uid) + DefaultParamsReader.getAndSetParams(ovr, metadata) + ovr.setClassifier(classifier) + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 5a596cad06..39999ede30 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -29,7 +29,7 @@ import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ -import org.apache.spark.ml.classification.OneVsRestParams +import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} import org.apache.spark.ml.feature.RFormulaModel import org.apache.spark.ml.param.{ParamPair, Params} import org.apache.spark.ml.tuning.ValidatorParams @@ -381,10 +381,8 @@ private[ml] object MetaAlgorithmReadWrite { case p: Pipeline => p.getStages.asInstanceOf[Array[Params]] case pm: PipelineModel => pm.stages.asInstanceOf[Array[Params]] case v: ValidatorParams => Array(v.getEstimator, v.getEvaluator) - case ovr: OneVsRestParams => - // TODO: SPARK-11892: This case may require special handling. 
- throw new UnsupportedOperationException(s"${instance.getClass.getName} write will fail" + - s" because it cannot yet handle an estimator containing type: ${ovr.getClass.getName}.") + case ovr: OneVsRest => Array(ovr.getClassifier) + case ovrModel: OneVsRestModel => Array(ovrModel.getClassifier) ++ ovrModel.models case rformModel: RFormulaModel => Array(rformModel.pipelineModel) case _: Params => Array() } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 2ae74a2090..51c1baf682 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MetadataUtils, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.Metadata -class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { +class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @transient var dataset: DataFrame = _ @transient var rdd: RDD[LabeledPoint] = _ @@ -160,6 +160,70 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext { require(m.getThreshold === 0.1, "copy should handle extra model params") } } + + test("read/write: OneVsRest") { + val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01) + + val ova = new OneVsRest() + .setClassifier(lr) + .setLabelCol("myLabel") + .setFeaturesCol("myFeature") + .setPredictionCol("myPrediction") + + val ova2 = testDefaultReadWrite(ova, testParams = false) + assert(ova.uid === ova2.uid) + assert(ova.getFeaturesCol === ova2.getFeaturesCol) + assert(ova.getLabelCol === ova2.getLabelCol) + assert(ova.getPredictionCol === ova2.getPredictionCol) + + ova2.getClassifier match { + case lr2: LogisticRegression => + assert(lr.uid === lr2.uid) + assert(lr.getMaxIter === lr2.getMaxIter) + assert(lr.getRegParam === lr2.getRegParam) + case other => + throw new AssertionError(s"Loaded OneVsRest expected classifier of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + } + + test("read/write: OneVsRestModel") { + def checkModelData(model: OneVsRestModel, model2: OneVsRestModel): Unit = { + assert(model.uid === model2.uid) + assert(model.getFeaturesCol === model2.getFeaturesCol) + assert(model.getLabelCol === model2.getLabelCol) + assert(model.getPredictionCol === model2.getPredictionCol) + + val classifier = model.getClassifier.asInstanceOf[LogisticRegression] + + model2.getClassifier match { + case lr2: LogisticRegression => + assert(classifier.uid === lr2.uid) + assert(classifier.getMaxIter === lr2.getMaxIter) + assert(classifier.getRegParam === lr2.getRegParam) + case other => + throw new AssertionError(s"Loaded OneVsRestModel expected classifier of type" + + s" LogisticRegression but found ${other.getClass.getName}") + } + + assert(model.labelMetadata === 
model2.labelMetadata) + model.models.zip(model2.models).foreach { + case (lrModel1: LogisticRegressionModel, lrModel2: LogisticRegressionModel) => + assert(lrModel1.uid === lrModel2.uid) + assert(lrModel1.coefficients === lrModel2.coefficients) + assert(lrModel1.intercept === lrModel2.intercept) + case other => + throw new AssertionError(s"Loaded OneVsRestModel expected model of type" + + s" LogisticRegressionModel but found ${other.getClass.getName}") + } + } + + val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01) + val ova = new OneVsRest().setClassifier(lr) + val ovaModel = ova.fit(dataset) + val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false) + checkModelData(ovaModel, newOvaModel) + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { -- cgit v1.2.3 From 8d6207206c9fd71178417c12cdacf368362df4d8 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Thu, 31 Mar 2016 11:53:13 -0700 Subject: [SPARK-14263][SQL] Benchmark Vectorized HashMap for GroupBy Aggregates ## What changes were proposed in this pull request? This PR proposes a new data-structure based on a vectorized hashmap that can be potentially _codegened_ in `TungstenAggregate` to speed up aggregates with group by. Micro-benchmarks show a 10x improvement over the current `BytesToBytes` aggregation map. ## How was this patch tested? Intel(R) Core(TM) i7-4960HQ CPU 2.60GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- hash 108 / 119 96.9 10.3 1.0X fast hash 63 / 70 166.2 6.0 1.7X arrayEqual 70 / 73 150.8 6.6 1.6X Java HashMap (Long) 141 / 200 74.3 13.5 0.8X Java HashMap (two ints) 145 / 185 72.3 13.8 0.7X Java HashMap (UnsafeRow) 499 / 524 21.0 47.6 0.2X BytesToBytesMap (off Heap) 483 / 548 21.7 46.0 0.2X BytesToBytesMap (on Heap) 485 / 562 21.6 46.2 0.2X Vectorized Hashmap 54 / 60 193.7 5.2 2.0X Author: Sameer Agarwal Closes #12055 from sameeragarwal/vectorized-hashmap. --- .../sql/execution/vectorized/AggregateHashMap.java | 107 +++++++++++++++++++++ .../sql/execution/BenchmarkWholeStageCodegen.scala | 45 +++++++-- 2 files changed, 142 insertions(+), 10 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java new file mode 100644 index 0000000000..abe8db589d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.vectorized; + +import java.util.Arrays; + +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.types.StructType; + +import static org.apache.spark.sql.types.DataTypes.LongType; + +/** + * This is an illustrative implementation of an append-only single-key/single value aggregate hash + * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates + * (and fall back to the `BytesToBytesMap` if a given key isn't found). This can be potentially + * 'codegened' in TungstenAggregate to speed up aggregates w/ key. + * + * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the + * key-value pairs. The index lookups in the array rely on linear probing (with a small number of + * maximum tries) and use an inexpensive hash function which makes it really efficient for a + * majority of lookups. However, using linear probing and an inexpensive hash function also makes it + * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even + * for certain distribution of keys) and requires us to fall back on the latter for correctness. + */ +public class AggregateHashMap { + public ColumnarBatch batch; + public int[] buckets; + + private int numBuckets; + private int numRows = 0; + private int maxSteps = 3; + + private static int DEFAULT_CAPACITY = 1 << 16; + private static double DEFAULT_LOAD_FACTOR = 0.25; + private static int DEFAULT_MAX_STEPS = 3; + + public AggregateHashMap(StructType schema, int capacity, double loadFactor, int maxSteps) { + + // We currently only support single key-value pair that are both longs + assert (schema.size() == 2 && schema.fields()[0].dataType() == LongType && + schema.fields()[1].dataType() == LongType); + + // capacity should be a power of 2 + assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); + + this.maxSteps = maxSteps; + numBuckets = (int) (capacity / loadFactor); + batch = ColumnarBatch.allocate(schema, MemoryMode.ON_HEAP, capacity); + buckets = new int[numBuckets]; + Arrays.fill(buckets, -1); + } + + public AggregateHashMap(StructType schema) { + this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS); + } + + public int findOrInsert(long key) { + int idx = find(key); + if (idx != -1 && buckets[idx] == -1) { + batch.column(0).putLong(numRows, key); + batch.column(1).putLong(numRows, 0); + buckets[idx] = numRows++; + } + return idx; + } + + public int find(long key) { + long h = hash(key); + int step = 0; + int idx = (int) h & (numBuckets - 1); + while (step < maxSteps) { + // Return bucket index if it's either an empty slot or already contains the key + if (buckets[idx] == -1) { + return idx; + } else if (equals(idx, key)) { + return idx; + } + idx = (idx + 1) & (numBuckets - 1); + step++; + } + // Didn't find it + return -1; + } + + private long hash(long key) { + return key; + } + + private boolean equals(int idx, long key1) { + return batch.column(0).getLong(buckets[idx]) == key1; + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index a16092e7d7..003d3e062e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -23,8 +23,9 @@ import org.apache.spark.{SparkConf, SparkContext, 
SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.vectorized.AggregateHashMap import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, LongType, StructType} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap @@ -463,18 +464,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + benchmark.addCase("Aggregate HashMap") { iter => + var i = 0 + val numKeys = 65536 + val schema = new StructType() + .add("key", LongType) + .add("value", LongType) + val map = new AggregateHashMap(schema) + while (i < numKeys) { + val idx = map.findOrInsert(i.toLong) + map.batch.column(1).putLong(map.buckets(idx), + map.batch.column(1).getLong(map.buckets(idx)) + 1) + i += 1 + } + var s = 0 + i = 0 + while (i < N) { + if (map.find(i % 100000) != -1) { + s += 1 + } + i += 1 + } + } + /** - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - hash 651 / 678 80.0 12.5 1.0X - fast hash 336 / 343 155.9 6.4 1.9X - arrayEqual 417 / 428 125.0 8.0 1.6X - Java HashMap (Long) 145 / 168 72.2 13.8 0.8X - Java HashMap (two ints) 157 / 164 66.8 15.0 0.8X - Java HashMap (UnsafeRow) 538 / 573 19.5 51.3 0.2X - BytesToBytesMap (off Heap) 2594 / 2664 20.2 49.5 0.2X - BytesToBytesMap (on Heap) 2693 / 2989 19.5 51.4 0.2X + hash 112 / 116 93.2 10.7 1.0X + fast hash 65 / 69 160.9 6.2 1.7X + arrayEqual 66 / 69 159.1 6.3 1.7X + Java HashMap (Long) 137 / 182 76.3 13.1 0.8X + Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X + Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X + BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X + BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X + Aggregate HashMap 56 / 62 187.9 5.3 2.0X */ benchmark.run() } -- cgit v1.2.3 From 3586929320bba8b8d09c2a451189f76821fdfba4 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Thu, 31 Mar 2016 11:56:28 -0700 Subject: [SPARK-14278][SQL] Initialize columnar batch with proper memory mode ## What changes were proposed in this pull request? Fixes a minor bug in the record reader constructor that was possibly introduced during refactoring. ## How was this patch tested? N/A Author: Sameer Agarwal Closes #12070 from sameeragarwal/vectorized-rr. 
--- .../execution/datasources/parquet/VectorizedParquetRecordReader.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 0bdf4aab29..a0b6276ef5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -191,7 +191,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa } } - columnarBatch = ColumnarBatch.allocate(batchSchema); + columnarBatch = ColumnarBatch.allocate(batchSchema, memMode); if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { -- cgit v1.2.3 From ac1b8b302a92678bbeece6e9c7879f1cb8fdad12 Mon Sep 17 00:00:00 2001 From: Nishkam Ravi Date: Thu, 31 Mar 2016 12:03:05 -0700 Subject: [SPARK-13796] Redirect error message to logWarning ## What changes were proposed in this pull request? Redirect error message to logWarning ## How was this patch tested? Unit tests, manual tests JoshRosen Author: Nishkam Ravi Closes #12052 from nishkamravi2/master_warning. --- core/src/main/scala/org/apache/spark/executor/Executor.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 3201463b8c..09c5733565 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -254,7 +254,7 @@ private[spark] class Executor( if (conf.getBoolean("spark.storage.exceptionOnPinLeak", false) && !threwException) { throw new SparkException(errMsg) } else { - logError(errMsg) + logWarning(errMsg) } } } -- cgit v1.2.3 From 446c45bd87035e20653394fcaf9dc8caa4299038 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 31 Mar 2016 12:03:55 -0700 Subject: [SPARK-14182][SQL] Parse DDL Command: Alter View This PR provides native parsing support for the DDL command `Alter View`. Since its AST trees are highly similar to those of `Alter Table`, both implementations are integrated into the same one. Based on the Hive DDL document: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL and https://cwiki.apache.org/confluence/display/Hive/PartitionedViews **Syntax:** ```SQL ALTER VIEW view_name RENAME TO new_view_name ``` - to change the name of a view to a different name **Syntax:** ```SQL ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); ``` - to add metadata to a view **Syntax:** ```SQL ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key') ``` - to remove metadata from a view **Syntax:** ```SQL ALTER VIEW view_name ADD [IF NOT EXISTS] PARTITION spec1[, PARTITION spec2, ...] ``` - to add the partitioning metadata for a view. - the syntax of partition spec in `ALTER VIEW` is identical to `ALTER TABLE`, **EXCEPT** that it is **ILLEGAL** to specify a `LOCATION` clause. **Syntax:** ```SQL ALTER VIEW view_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] ``` - to drop the related partition metadata for a view. 
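As an illustrative sketch (it follows the pattern of the new tests in `DDLCommandSuite`; the `parser` value stands for the SQL parser that suite exercises and its construction is omitted here), an `ALTER VIEW` statement now parses to the same native command node as its `ALTER TABLE` counterpart:

```scala
import org.apache.spark.sql.execution.command.AlterTableRename

// `parser` is assumed to expose parsePlan(sql: String), as used in DDLCommandSuite.
val fromTable = parser.parsePlan("ALTER TABLE table_name RENAME TO new_table_name")
val fromView = parser.parsePlan("ALTER VIEW table_name RENAME TO new_table_name")

// Both statements yield AlterTableRename; only the recorded SQL text differs.
assert(fromTable.isInstanceOf[AlterTableRename])
assert(fromView.isInstanceOf[AlterTableRename])
```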
Added the related test cases to `DDLCommandSuite` Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #11987 from gatorsmile/parseAlterView. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 20 +-- .../spark/sql/execution/SparkSqlParser.scala | 20 ++- .../apache/spark/sql/execution/command/ddl.scala | 17 ++ .../sql/execution/command/DDLCommandSuite.scala | 175 ++++++++++++++++----- 4 files changed, 177 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 3b9f82a80f..a857e670da 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -57,10 +57,11 @@ statement (AS? query)? #createTable | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq?)? #analyze - | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable - | ALTER TABLE tableIdentifier + | ALTER (TABLE | VIEW) from=tableIdentifier + RENAME TO to=tableIdentifier #renameTable + | ALTER (TABLE | VIEW) tableIdentifier SET TBLPROPERTIES tablePropertyList #setTableProperties - | ALTER TABLE tableIdentifier + | ALTER (TABLE | VIEW) tableIdentifier UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties | ALTER TABLE tableIdentifier (partitionSpec)? SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe @@ -76,12 +77,16 @@ statement SET SKEWED LOCATION skewedLocationList #setTableSkewLocations | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? partitionSpecLocation+ #addTablePartition + | ALTER VIEW tableIdentifier ADD (IF NOT EXISTS)? + partitionSpec+ #addTablePartition | ALTER TABLE tableIdentifier from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition | ALTER TABLE from=tableIdentifier EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition | ALTER TABLE tableIdentifier DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER VIEW tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition | ALTER TABLE tableIdentifier partitionSpec? @@ -133,15 +138,6 @@ hiveNativeCommands | DELETE FROM tableIdentifier (WHERE booleanExpression)? | TRUNCATE TABLE tableIdentifier partitionSpec? (COLUMNS identifierList)? - | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier - | ALTER VIEW from=tableIdentifier AS? - SET TBLPROPERTIES tablePropertyList - | ALTER VIEW from=tableIdentifier AS? - UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList - | ALTER VIEW from=tableIdentifier AS? - ADD (IF NOT EXISTS)? partitionSpecLocation+ - | ALTER VIEW from=tableIdentifier AS? - DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? | DROP VIEW (IF EXISTS)? qualifiedName | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? | START TRANSACTION (transactionMode (',' transactionMode)*)? 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index b4687c985d..16a899e01f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -335,6 +335,7 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; * }}} */ override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { @@ -350,6 +351,7 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); + * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); * }}} */ override def visitSetTableProperties( @@ -366,6 +368,7 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); + * ALTER VIEW view UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); * }}} */ override def visitUnsetTableProperties( @@ -510,16 +513,22 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] + * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec * }}} */ override def visitAddTablePartition( ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { // Create partition spec to location mapping. - val specsAndLocs = ctx.partitionSpecLocation.asScala.map { - splCtx => - val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec) - val location = Option(splCtx.locationSpec).map(visitLocationSpec) - spec -> location + val specsAndLocs = if (ctx.partitionSpec.isEmpty) { + ctx.partitionSpecLocation.asScala.map { + splCtx => + val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec) + val location = Option(splCtx.locationSpec).map(visitLocationSpec) + spec -> location + } + } else { + // Alter View: the location clauses are not allowed. + ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec(_) -> None) } AlterTableAddPartition( visitTableIdentifier(ctx.tableIdentifier), @@ -568,6 +577,7 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; + * ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]; * }}} */ override def visitDropTablePartitions( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 6c2a67f81c..cd7e0eed65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -195,16 +195,19 @@ case class DropFunction( isTemp: Boolean)(sql: String) extends NativeDDLCommand(sql) with Logging +/** Rename in ALTER TABLE/VIEW: change the name of a table/view to a different name. */ case class AlterTableRename( oldName: TableIdentifier, newName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging +/** Set Properties in ALTER TABLE/VIEW: add metadata to a table/view. */ case class AlterTableSetProperties( tableName: TableIdentifier, properties: Map[String, String])(sql: String) extends NativeDDLCommand(sql) with Logging +/** Unset Properties in ALTER TABLE/VIEW: remove metadata from a table/view. 
*/ case class AlterTableUnsetProperties( tableName: TableIdentifier, properties: Map[String, String], @@ -253,6 +256,12 @@ case class AlterTableSkewedLocation( skewedMap: Map[String, String])(sql: String) extends NativeDDLCommand(sql) with Logging +/** + * Add Partition in ALTER TABLE/VIEW: add the table/view partitions. + * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE, + * EXCEPT that it is ILLEGAL to specify a LOCATION clause. + * An error message will be issued if the partition exists, unless 'ifNotExists' is true. + */ case class AlterTableAddPartition( tableName: TableIdentifier, partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])], @@ -271,6 +280,14 @@ case class AlterTableExchangePartition( spec: TablePartitionSpec)(sql: String) extends NativeDDLCommand(sql) with Logging +/** + * Drop Partition in ALTER TABLE/VIEW: to drop a particular partition for a table/view. + * This removes the data and metadata for this partition. + * The data is actually moved to the .Trash/Current directory if Trash is configured, + * unless 'purge' is true, but the metadata is completely lost. + * An error message will be issued if the partition does not exist, unless 'ifExists' is true. + * Note: purge is always false when the target is a view. + */ case class AlterTableDropPartition( tableName: TableIdentifier, specs: Seq[TablePartitionSpec], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index ccbfd41cca..cebf9c856d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -195,33 +195,60 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed4, expected4) } - test("alter table: rename table") { - val sql = "ALTER TABLE table_name RENAME TO new_table_name" - val parsed = parser.parsePlan(sql) - val expected = AlterTableRename( + // ALTER TABLE table_name RENAME TO new_table_name; + // ALTER VIEW view_name RENAME TO new_view_name; + test("alter table/view: rename table/view") { + val sql_table = "ALTER TABLE table_name RENAME TO new_table_name" + val sql_view = sql_table.replace("TABLE", "VIEW") + val parsed_table = parser.parsePlan(sql_table) + val parsed_view = parser.parsePlan(sql_view) + val expected_table = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None))(sql) - comparePlans(parsed, expected) + TableIdentifier("new_table_name", None))(sql_table) + val expected_view = AlterTableRename( + TableIdentifier("table_name", None), + TableIdentifier("new_table_name", None))(sql_view) + comparePlans(parsed_table, expected_table) + comparePlans(parsed_view, expected_view) } - test("alter table: alter table properties") { - val sql1 = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + + // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + test("alter table/view: alter table/view properties") { + val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + "'comment' = 'new_comment')" - val sql2 = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" - val 
sql3 = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) + val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" + val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" + val sql1_view = sql1_table.replace("TABLE", "VIEW") + val sql2_view = sql2_table.replace("TABLE", "VIEW") + val sql3_view = sql3_table.replace("TABLE", "VIEW") + + val parsed1_table = parser.parsePlan(sql1_table) + val parsed2_table = parser.parsePlan(sql2_table) + val parsed3_table = parser.parsePlan(sql3_table) + val parsed1_view = parser.parsePlan(sql1_view) + val parsed2_view = parser.parsePlan(sql2_view) + val parsed3_view = parser.parsePlan(sql3_view) + val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSetProperties( - tableIdent, Map("test" -> "test", "comment" -> "new_comment"))(sql1) - val expected2 = AlterTableUnsetProperties( - tableIdent, Map("comment" -> null, "test" -> null), ifExists = false)(sql2) - val expected3 = AlterTableUnsetProperties( - tableIdent, Map("comment" -> null, "test" -> null), ifExists = true)(sql3) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) + val expected1_table = AlterTableSetProperties( + tableIdent, Map("test" -> "test", "comment" -> "new_comment"))(sql1_table) + val expected2_table = AlterTableUnsetProperties( + tableIdent, Map("comment" -> null, "test" -> null), ifExists = false)(sql2_table) + val expected3_table = AlterTableUnsetProperties( + tableIdent, Map("comment" -> null, "test" -> null), ifExists = true)(sql3_table) + val expected1_view = expected1_table.copy()(sql = sql1_view) + val expected2_view = expected2_table.copy()(sql = sql2_view) + val expected3_view = expected3_table.copy()(sql = sql3_view) + + comparePlans(parsed1_table, expected1_table) + comparePlans(parsed2_table, expected2_table) + comparePlans(parsed3_table, expected3_table) + comparePlans(parsed1_view, expected1_view) + comparePlans(parsed2_view, expected2_view) + comparePlans(parsed3_view, expected3_view) } test("alter table: SerDe properties") { @@ -376,21 +403,66 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION partition_spec + // [LOCATION 'location1'] partition_spec [LOCATION 'location2'] ...; test("alter table: add partition") { - val sql = + val sql1 = """ |ALTER TABLE table_name ADD IF NOT EXISTS PARTITION |(dt='2008-08-08', country='us') LOCATION 'location1' PARTITION |(dt='2009-09-09', country='uk') """.stripMargin - val parsed = parser.parsePlan(sql) - val expected = AlterTableAddPartition( + val sql2 = "ALTER TABLE table_name ADD PARTITION (dt='2008-08-08') LOCATION 'loc'" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = AlterTableAddPartition( TableIdentifier("table_name", None), Seq( (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), - ifNotExists = true)(sql) - comparePlans(parsed, expected) + ifNotExists = true)(sql1) + val expected2 = AlterTableAddPartition( + TableIdentifier("table_name", None), + Seq((Map("dt" -> "2008-08-08"), Some("loc"))), + ifNotExists = false)(sql2) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + // ALTER VIEW view_name ADD [IF NOT 
EXISTS] PARTITION partition_spec PARTITION partition_spec ...; + test("alter view: add partition") { + val sql1 = + """ + |ALTER VIEW view_name ADD IF NOT EXISTS PARTITION + |(dt='2008-08-08', country='us') PARTITION + |(dt='2009-09-09', country='uk') + """.stripMargin + // different constant types in partitioning spec + val sql2 = + """ + |ALTER VIEW view_name ADD PARTITION + |(col1=NULL, cOL2='f', col3=5, COL4=true) + """.stripMargin + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = AlterTableAddPartition( + TableIdentifier("view_name", None), + Seq( + (Map("dt" -> "2008-08-08", "country" -> "us"), None), + (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), + ifNotExists = true)(sql1) + val expected2 = AlterTableAddPartition( + TableIdentifier("view_name", None), + Seq((Map("col1" -> "NULL", "col2" -> "f", "col3" -> "5", "col4" -> "true"), None)), + ifNotExists = false)(sql2) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) } test("alter table: rename partition") { @@ -421,36 +493,63 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed, expected) } - test("alter table: drop partitions") { - val sql1 = + // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE] + // ALTER VIEW table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] + test("alter table/view: drop partitions") { + val sql1_table = """ |ALTER TABLE table_name DROP IF EXISTS PARTITION |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') """.stripMargin - val sql2 = + val sql2_table = """ |ALTER TABLE table_name DROP PARTITION |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') PURGE """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) + val sql1_view = sql1_table.replace("TABLE", "VIEW") + // Note: ALTER VIEW DROP PARTITION does not support PURGE + val sql2_view = sql2_table.replace("TABLE", "VIEW").replace("PURGE", "") + + val parsed1_table = parser.parsePlan(sql1_table) + val parsed2_table = parser.parsePlan(sql2_table) + val parsed1_view = parser.parsePlan(sql1_view) + val parsed2_view = parser.parsePlan(sql2_view) + val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableDropPartition( + val expected1_table = AlterTableDropPartition( tableIdent, Seq( Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2009-09-09", "country" -> "uk")), ifExists = true, - purge = false)(sql1) - val expected2 = AlterTableDropPartition( + purge = false)(sql1_table) + val expected2_table = AlterTableDropPartition( tableIdent, Seq( Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2009-09-09", "country" -> "uk")), ifExists = false, - purge = true)(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + purge = true)(sql2_table) + + val expected1_view = AlterTableDropPartition( + tableIdent, + Seq( + Map("dt" -> "2008-08-08", "country" -> "us"), + Map("dt" -> "2009-09-09", "country" -> "uk")), + ifExists = true, + purge = false)(sql1_view) + val expected2_view = AlterTableDropPartition( + tableIdent, + Seq( + Map("dt" -> "2008-08-08", "country" -> "us"), + Map("dt" -> "2009-09-09", "country" -> "uk")), + ifExists = false, + purge = false)(sql2_table) + + comparePlans(parsed1_table, expected1_table) + comparePlans(parsed2_table, expected2_table) + comparePlans(parsed1_view, expected1_view) + comparePlans(parsed2_view, expected2_view) } test("alter table: archive 
partition") { -- cgit v1.2.3 From 8a333d2da859fd593bda183413630bc3757529c9 Mon Sep 17 00:00:00 2001 From: jeanlyn Date: Thu, 31 Mar 2016 12:04:42 -0700 Subject: [SPARK-14243][CORE] update task metrics when removing blocks ## What changes were proposed in this pull request? This PR try to use `incUpdatedBlockStatuses ` to update the `updatedBlockStatuses ` when removing blocks, making sure `BlockManager` correctly updates `updatedBlockStatuses` ## How was this patch tested? test("updated block statuses") in BlockManagerSuite.scala Author: jeanlyn Closes #12091 from jeanlyn/updateBlock. --- .../src/main/scala/org/apache/spark/storage/BlockManager.scala | 7 +++++-- .../scala/org/apache/spark/storage/BlockManagerSuite.scala | 10 ++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 0c7763f236..3014cafc28 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1264,9 +1264,12 @@ private[spark] class BlockManager( "the disk, memory, or external block store") } blockInfoManager.removeBlock(blockId) + val removeBlockStatus = getCurrentBlockStatus(blockId, info) if (tellMaster && info.tellMaster) { - val status = getCurrentBlockStatus(blockId, info) - reportBlockStatus(blockId, info, status) + reportBlockStatus(blockId, info, removeBlockStatus) + } + Option(TaskContext.get()).foreach { c => + c.taskMetrics().incUpdatedBlockStatuses(Seq((blockId, removeBlockStatus))) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 6fc32cb30a..9f3a775654 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -928,6 +928,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(!store.diskStore.contains("list3"), "list3 was in disk store") assert(!store.diskStore.contains("list4"), "list4 was in disk store") assert(!store.diskStore.contains("list5"), "list5 was in disk store") + + // remove block - list2 should be removed from disk + val updatedBlocks6 = getUpdatedBlocks { + store.removeBlock( + "list2", tellMaster = true) + } + assert(updatedBlocks6.size === 1) + assert(updatedBlocks6.head._1 === TestBlockId("list2")) + assert(updatedBlocks6.head._2.storageLevel == StorageLevel.NONE) + assert(!store.diskStore.contains("list2"), "list2 was in disk store") } test("query block statuses") { -- cgit v1.2.3 From 4d93b653f7294698526674950d3dc303691260f8 Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Thu, 31 Mar 2016 12:06:16 -0700 Subject: [Docs] Update monitoring.md to accurately describe the history server It looks like the docs were recently updated to reflect the History Server's support for incomplete applications, but they still had wording that suggested only completed applications were viewable. This fixes that. My editor also introduced several whitespace removal changes, that I hope are OK, as text files shouldn't have trailing whitespace. To verify they're purely whitespace changes, add `&w=1` to your browser address. If this isn't acceptable, let me know and I'll update the PR. I also didn't think this required a JIRA. Let me know if I should create one. 
Not tested Author: Michael Gummelt Closes #12045 from mgummelt/update-history-docs. --- docs/monitoring.md | 58 +++++++++++++++++++++++++++--------------------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/docs/monitoring.md b/docs/monitoring.md index c139e1cb5a..32d2e02e93 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -8,7 +8,7 @@ There are several ways to monitor Spark applications: web UIs, metrics, and exte # Web Interfaces -Every SparkContext launches a web UI, by default on port 4040, that +Every SparkContext launches a web UI, by default on port 4040, that displays useful information about the application. This includes: * A list of scheduler stages and tasks @@ -32,19 +32,19 @@ Spark's Standalone Mode cluster manager also has its own the course of its lifetime, then the Standalone master's web UI will automatically re-render the application's UI after the application has finished. -If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of a finished +If Spark is run on Mesos or YARN, it is still possible to construct the UI of an application through Spark's history server, provided that the application's event logs exist. You can start the history server by executing: ./sbin/start-history-server.sh This creates a web interface at `http://:18080` by default, listing incomplete -and completed applications and attempts, and allowing them to be viewed +and completed applications and attempts. When using the file-system provider class (see `spark.history.provider` below), the base logging directory must be supplied in the `spark.history.fs.logDirectory` configuration option, and should contain sub-directories that each represents an application's event logs. - + The spark jobs themselves must be configured to log events, and to log them to the same shared, writeable directory. For example, if the server was configured with a log directory of `hdfs://namenode/shared/spark-logs`, then the client-side options would be: @@ -53,7 +53,7 @@ writeable directory. For example, if the server was configured with a log direct spark.eventLog.enabled true spark.eventLog.dir hdfs://namenode/shared/spark-logs ``` - + The history server can be configured as follows: ### Environment Variables @@ -135,9 +135,9 @@ The history server can be configured as follows: @@ -159,12 +159,12 @@ The history server can be configured as follows: @@ -298,14 +298,14 @@ keep the paths consistent in both modes. # Metrics -Spark has a configurable metrics system based on the -[Coda Hale Metrics Library](http://metrics.codahale.com/). -This allows users to report Spark metrics to a variety of sinks including HTTP, JMX, and CSV -files. The metrics system is configured via a configuration file that Spark expects to be present -at `$SPARK_HOME/conf/metrics.properties`. A custom file location can be specified via the +Spark has a configurable metrics system based on the +[Coda Hale Metrics Library](http://metrics.codahale.com/). +This allows users to report Spark metrics to a variety of sinks including HTTP, JMX, and CSV +files. The metrics system is configured via a configuration file that Spark expects to be present +at `$SPARK_HOME/conf/metrics.properties`. A custom file location can be specified via the `spark.metrics.conf` [configuration property](configuration.html#spark-properties). -Spark's metrics are decoupled into different -_instances_ corresponding to Spark components. 
Within each instance, you can configure a +Spark's metrics are decoupled into different +_instances_ corresponding to Spark components. Within each instance, you can configure a set of sinks to which metrics are reported. The following instances are currently supported: * `master`: The Spark standalone master process. @@ -330,26 +330,26 @@ licensing restrictions: * `GangliaSink`: Sends metrics to a Ganglia node or multicast group. To install the `GangliaSink` you'll need to perform a custom build of Spark. _**Note that -by embedding this library you will include [LGPL](http://www.gnu.org/copyleft/lesser.html)-licensed -code in your Spark package**_. For sbt users, set the -`SPARK_GANGLIA_LGPL` environment variable before building. For Maven users, enable +by embedding this library you will include [LGPL](http://www.gnu.org/copyleft/lesser.html)-licensed +code in your Spark package**_. For sbt users, set the +`SPARK_GANGLIA_LGPL` environment variable before building. For Maven users, enable the `-Pspark-ganglia-lgpl` profile. In addition to modifying the cluster's Spark build user applications will need to link to the `spark-ganglia-lgpl` artifact. -The syntax of the metrics configuration file is defined in an example configuration file, +The syntax of the metrics configuration file is defined in an example configuration file, `$SPARK_HOME/conf/metrics.properties.template`. # Advanced Instrumentation Several external tools can be used to help profile the performance of Spark jobs: -* Cluster-wide monitoring tools, such as [Ganglia](http://ganglia.sourceforge.net/), can provide -insight into overall cluster utilization and resource bottlenecks. For instance, a Ganglia -dashboard can quickly reveal whether a particular workload is disk bound, network bound, or +* Cluster-wide monitoring tools, such as [Ganglia](http://ganglia.sourceforge.net/), can provide +insight into overall cluster utilization and resource bottlenecks. For instance, a Ganglia +dashboard can quickly reveal whether a particular workload is disk bound, network bound, or CPU bound. -* OS profiling tools such as [dstat](http://dag.wieers.com/home-made/dstat/), -[iostat](http://linux.die.net/man/1/iostat), and [iotop](http://linux.die.net/man/1/iotop) +* OS profiling tools such as [dstat](http://dag.wieers.com/home-made/dstat/), +[iostat](http://linux.die.net/man/1/iostat), and [iotop](http://linux.die.net/man/1/iotop) can provide fine-grained profiling on individual nodes. -* JVM utilities such as `jstack` for providing stack traces, `jmap` for creating heap-dumps, -`jstat` for reporting time-series statistics and `jconsole` for visually exploring various JVM +* JVM utilities such as `jstack` for providing stack traces, `jmap` for creating heap-dumps, +`jstat` for reporting time-series statistics and `jconsole` for visually exploring various JVM properties are useful for those comfortable with JVM internals. -- cgit v1.2.3 From 0abee534f0ad9bbe84d8d3d3478ecaa594f1e0f4 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 31 Mar 2016 12:07:19 -0700 Subject: [SPARK-14069][SQL] Improve SparkStatusTracker to also track executor information ## What changes were proposed in this pull request? Track executor information like host and port, cache size, running tasks. TODO: tests ## How was this patch tested? N/A Author: Wenchen Fan Closes #11888 from cloud-fan/status-tracker. 
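For illustration only (not part of the patch), a minimal sketch of how the new executor-information API could be consumed, assuming `sc` is an active `SparkContext`:

```scala
// Sketch: inspect per-executor state through the new status tracker API.
// Assumes `sc` is an active SparkContext; the println formatting is arbitrary.
sc.statusTracker.getExecutorInfos.foreach { info =>
  println(s"${info.host}:${info.port} cacheSize=${info.cacheSize} runningTasks=${info.numRunningTasks}")
}
```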
--- .../java/org/apache/spark/SparkExecutorInfo.java | 33 ++++++++++++++++++++++ .../main/scala/org/apache/spark/SparkContext.scala | 3 +- .../org/apache/spark/SparkStatusTracker.scala | 20 +++++++++++++ .../scala/org/apache/spark/StatusAPIImpl.scala | 33 +++++++++++++--------- .../apache/spark/scheduler/TaskSchedulerImpl.scala | 2 ++ .../org/apache/spark/storage/StorageUtils.scala | 5 +++- 6 files changed, 80 insertions(+), 16 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/SparkExecutorInfo.java diff --git a/core/src/main/java/org/apache/spark/SparkExecutorInfo.java b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java new file mode 100644 index 0000000000..dc3e826475 --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkExecutorInfo.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +import java.io.Serializable; + +/** + * Exposes information about Spark Executors. + * + * This interface is not designed to be implemented outside of Spark. We may add additional methods + * which may break binary compatibility with outside implementations. + */ +public interface SparkExecutorInfo extends Serializable { + String host(); + int port(); + long cacheSize(); + int numRunningTasks(); +} diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index dcb41f3a40..d7cb253d69 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -147,8 +147,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli appName: String, sparkHome: String = null, jars: Seq[String] = Nil, - environment: Map[String, String] = Map()) = - { + environment: Map[String, String] = Map()) = { this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) } diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala index 34ee3a48f8..52c4656c27 100644 --- a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -17,6 +17,8 @@ package org.apache.spark +import org.apache.spark.scheduler.TaskSchedulerImpl + /** * Low-level status reporting APIs for monitoring job and stage progress. * @@ -104,4 +106,22 @@ class SparkStatusTracker private[spark] (sc: SparkContext) { } } } + + /** + * Returns information of all known executors, including host, port, cacheSize, numRunningTasks. 
+ */ + def getExecutorInfos: Array[SparkExecutorInfo] = { + val executorIdToRunningTasks: Map[String, Int] = + sc.taskScheduler.asInstanceOf[TaskSchedulerImpl].runningTasksByExecutors() + + sc.getExecutorStorageStatus.map { status => + val bmId = status.blockManagerId + new SparkExecutorInfoImpl( + bmId.host, + bmId.port, + status.cacheSize, + executorIdToRunningTasks.getOrElse(bmId.executorId, 0) + ) + } + } } diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala index e5c7c8d0db..c1f24a6377 100644 --- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -18,18 +18,25 @@ package org.apache.spark private class SparkJobInfoImpl ( - val jobId: Int, - val stageIds: Array[Int], - val status: JobExecutionStatus) - extends SparkJobInfo + val jobId: Int, + val stageIds: Array[Int], + val status: JobExecutionStatus) + extends SparkJobInfo private class SparkStageInfoImpl( - val stageId: Int, - val currentAttemptId: Int, - val submissionTime: Long, - val name: String, - val numTasks: Int, - val numActiveTasks: Int, - val numCompletedTasks: Int, - val numFailedTasks: Int) - extends SparkStageInfo + val stageId: Int, + val currentAttemptId: Int, + val submissionTime: Long, + val name: String, + val numTasks: Int, + val numActiveTasks: Int, + val numCompletedTasks: Int, + val numFailedTasks: Int) + extends SparkStageInfo + +private class SparkExecutorInfoImpl( + val host: String, + val port: Int, + val cacheSize: Long, + val numRunningTasks: Int) + extends SparkExecutorInfo diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index f7790fccc6..daed2ff50e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -90,6 +90,8 @@ private[spark] class TaskSchedulerImpl( // Number of tasks running on each executor private val executorIdToTaskCount = new HashMap[String, Int] + def runningTasksByExecutors(): Map[String, Int] = executorIdToTaskCount.toMap + // The set of executors we have on each host; this is used to compute hostsAlive, which // in turn is used to decide when we can attain data locality on a given host protected val executorsByHost = new HashMap[String, HashSet[String]] diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 199a5fc270..fb9941bbd9 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -175,7 +175,10 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { def memRemaining: Long = maxMem - memUsed /** Return the memory used by this block manager. */ - def memUsed: Long = _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + def memUsed: Long = _nonRddStorageInfo._1 + cacheSize + + /** Return the memory used by caching RDDs */ + def cacheSize: Long = _rddBlocks.keys.toSeq.map(memUsedByRdd).sum /** Return the disk space used by this block manager. 
*/ def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum -- cgit v1.2.3 From 10508f36adcb74a563010636dffcd1f68efd8468 Mon Sep 17 00:00:00 2001 From: Jo Voordeckers Date: Thu, 31 Mar 2016 12:08:10 -0700 Subject: [SPARK-11327][MESOS] Dispatcher does not respect all args from the Submit request Supersedes https://github.com/apache/spark/pull/9752 Author: Jo Voordeckers Author: Iulian Dragos Closes #10370 from jayv/mesos_cluster_params. --- .../cluster/mesos/MesosClusterScheduler.scala | 26 ++++++++++++++++ .../cluster/mesos/MesosClusterSchedulerSuite.scala | 36 ++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 2df7b1120b..c41fa58607 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -423,6 +423,12 @@ private[spark] class MesosClusterScheduler( "--driver-cores", desc.cores.toString, "--driver-memory", s"${desc.mem}M") + val replicatedOptionsBlacklist = Set( + "spark.jars", // Avoids duplicate classes in classpath + "spark.submit.deployMode", // this would be set to `cluster`, but we need client + "spark.master" // this contains the address of the dispatcher, not master + ) + // Assume empty main class means we're running python if (!desc.command.mainClass.equals("")) { options ++= Seq("--class", desc.command.mainClass) @@ -440,9 +446,29 @@ private[spark] class MesosClusterScheduler( .mkString(",") options ++= Seq("--py-files", formattedFiles) } + desc.schedulerProperties + .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } + .foreach { case (key, value) => options ++= Seq("--conf", s"$key=${shellEscape(value)}") } options } + /** + * Escape args for Unix-like shells, unless already quoted by the user. + * Based on: http://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html + * and http://www.grymoire.com/Unix/Quote.html + * @param value argument + * @return escaped argument + */ + private[scheduler] def shellEscape(value: String): String = { + val WrappedInQuotes = """^(".+"|'.+')$""".r + val ShellSpecialChars = (""".*([ '<>&|\?\*;!#\\(\)"$`]).*""").r + value match { + case WrappedInQuotes(c) => value // The user quoted his args, don't touch it! 
+ case ShellSpecialChars(c) => "\"" + value.replaceAll("""(["`\$\\])""", """\\$1""") + "\"" + case _: String => value // Don't touch harmless strings + } + } + private class ResourceOffer( val offerId: OfferID, val slaveId: SlaveID, diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index dbef6868f2..a32423dc4f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -136,4 +136,40 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi capture.capture() ) } + + test("escapes commandline args for the shell") { + val conf = new SparkConf() + conf.setMaster("mesos://localhost:5050") + conf.setAppName("spark mesos") + val scheduler = new MesosClusterScheduler( + new BlackHoleMesosClusterPersistenceEngineFactory, conf) { + override def start(): Unit = { ready = true } + } + val escape = scheduler.shellEscape _ + def wrapped(str: String): String = "\"" + str + "\"" + + // Wrapped in quotes + assert(escape("'should be left untouched'") === "'should be left untouched'") + assert(escape("\"should be left untouched\"") === "\"should be left untouched\"") + + // Harmless + assert(escape("") === "") + assert(escape("harmless") === "harmless") + assert(escape("har-m.l3ss") === "har-m.l3ss") + + // Special Chars escape + assert(escape("should escape this \" quote") === wrapped("should escape this \\\" quote")) + assert(escape("shouldescape\"quote") === wrapped("shouldescape\\\"quote")) + assert(escape("should escape this $ dollar") === wrapped("should escape this \\$ dollar")) + assert(escape("should escape this ` backtick") === wrapped("should escape this \\` backtick")) + assert(escape("""should escape this \ backslash""") + === wrapped("""should escape this \\ backslash""")) + assert(escape("""\"?""") === wrapped("""\\\"?""")) + + + // Special Chars no escape only wrap + List(" ", "'", "<", ">", "&", "|", "?", "*", ";", "!", "#", "(", ")").foreach(char => { + assert(escape(s"onlywrap${char}this") === wrapped(s"onlywrap${char}this")) + }) + } } -- cgit v1.2.3 From 3cfbeb70b1feb1f3a8c4d0b2d2f3715a356c80f2 Mon Sep 17 00:00:00 2001 From: Michel Lemay Date: Thu, 31 Mar 2016 12:15:32 -0700 Subject: [SPARK-13710][SHELL][WINDOWS] Fix jline dependency on Windows ## What changes were proposed in this pull request? Exclude jline from curator-recipes since it conflicts with scala 2.11 when running spark-shell. Should not affect scala 2.10 since it is builtin. ## How was this patch tested? Ran spark-shell manually. Author: Michel Lemay Closes #12043 from michellemay/spark-13710-fix-jline-on-windows. --- pom.xml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pom.xml b/pom.xml index 9dab0bca74..25d6136421 100644 --- a/pom.xml +++ b/pom.xml @@ -733,6 +733,10 @@ org.jboss.netty netty + + jline + jline + -- cgit v1.2.3 From e785402826dcd984d9312470464714ba6c908a49 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Thu, 31 Mar 2016 12:17:25 -0700 Subject: [SPARK-14304][SQL][TESTS] Fix tests that don't create temp files in the `java.io.tmpdir` folder ## What changes were proposed in this pull request? If I press `CTRL-C` when running these tests, the temp files will be left in `sql/core` folder and I need to delete them manually. It's annoying. 
This PR just moves the temp files to the `java.io.tmpdir` folder and add a name prefix for them. ## How was this patch tested? Existing Jenkins tests Author: Shixiong Zhu Closes #12093 from zsxwing/temp-file. --- .../scala/org/apache/spark/sql/StreamTest.scala | 2 +- .../streaming/ContinuousQueryManagerSuite.scala | 3 ++- .../sql/streaming/DataFrameReaderWriterSuite.scala | 3 ++- .../spark/sql/streaming/FileStreamSinkSuite.scala | 4 ++-- .../sql/streaming/FileStreamSourceSuite.scala | 26 +++++++++++----------- .../spark/sql/streaming/FileStressSuite.scala | 8 +++---- 6 files changed, 24 insertions(+), 22 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 4ca739450c..b5be7ef47e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -265,7 +265,7 @@ trait StreamTest extends QueryTest with Timeouts { } val testThread = Thread.currentThread() - val metadataRoot = Utils.createTempDir("streaming.metadata").getCanonicalPath + val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath try { startedTest.foreach { action => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index 54ce98d195..29bd3e018e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -236,7 +236,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with @volatile var query: StreamExecution = null try { val df = ds.toDF - val metadataRoot = Utils.createTempDir("streaming.metadata").getCanonicalPath + val metadataRoot = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath query = sqlContext .streams .startQuery( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index c1bab9b577..102473d7d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -69,7 +69,8 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with BeforeAndAfter { import testImplicits._ - private def newMetadataDir = Utils.createTempDir("streaming.metadata").getCanonicalPath + private def newMetadataDir = + Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath after { sqlContext.streams.active.foreach(_.stop()) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 7f31611383..8cf5dedabc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -29,8 +29,8 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val inputData = MemoryStream[Int] val df = inputData.toDF() - val outputDir = Utils.createTempDir("stream.output").getCanonicalPath - val 
checkpointDir = Utils.createTempDir("stream.checkpoint").getCanonicalPath + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath val query = df.write diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 89de15acf5..054f5c9fa2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -202,8 +202,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } test("read from text files") { - val src = Utils.createTempDir("streaming.src") - val tmp = Utils.createTempDir("streaming.tmp") + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") val textSource = createFileStreamSource("text", src.getCanonicalPath) val filtered = textSource.toDF().filter($"value" contains "keep") @@ -224,8 +224,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } test("read from json files") { - val src = Utils.createTempDir("streaming.src") - val tmp = Utils.createTempDir("streaming.tmp") + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") val textSource = createFileStreamSource("json", src.getCanonicalPath, Some(valueSchema)) val filtered = textSource.toDF().filter($"value" contains "keep") @@ -258,8 +258,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } test("read from json files with inferring schema") { - val src = Utils.createTempDir("streaming.src") - val tmp = Utils.createTempDir("streaming.tmp") + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") // Add a file so that we can infer its schema stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}") @@ -279,8 +279,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } test("read from parquet files") { - val src = Utils.createTempDir("streaming.src") - val tmp = Utils.createTempDir("streaming.tmp") + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") val fileSource = createFileStreamSource("parquet", src.getCanonicalPath, Some(valueSchema)) val filtered = fileSource.toDF().filter($"value" contains "keep") @@ -301,7 +301,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } test("file stream source without schema") { - val src = Utils.createTempDir("streaming.src") + val src = Utils.createTempDir(namePrefix = "streaming.src") // Only "text" doesn't need a schema createFileStreamSource("text", src.getCanonicalPath) @@ -318,8 +318,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } test("fault tolerance") { - val src = Utils.createTempDir("streaming.src") - val tmp = Utils.createTempDir("streaming.tmp") + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") val textSource = createFileStreamSource("text", src.getCanonicalPath) val filtered = textSource.toDF().filter($"value" contains "keep") @@ -346,8 +346,8 
@@ class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQ import testImplicits._ test("file source stress test") { - val src = Utils.createTempDir("streaming.src") - val tmp = Utils.createTempDir("streaming.tmp") + val src = Utils.createTempDir(namePrefix = "streaming.src") + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") val textSource = createFileStreamSource("text", src.getCanonicalPath) val ds = textSource.toDS[String]().map(_.toInt + 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 5a1bfb3a00..3916430cdf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -43,10 +43,10 @@ class FileStressSuite extends StreamTest with SharedSQLContext { test("fault tolerance stress test") { val numRecords = 10000 - val inputDir = Utils.createTempDir("stream.input").getCanonicalPath - val stagingDir = Utils.createTempDir("stream.staging").getCanonicalPath - val outputDir = Utils.createTempDir("stream.output").getCanonicalPath - val checkpoint = Utils.createTempDir("stream.checkpoint").getCanonicalPath + val inputDir = Utils.createTempDir(namePrefix = "stream.input").getCanonicalPath + val stagingDir = Utils.createTempDir(namePrefix = "stream.staging").getCanonicalPath + val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath + val checkpoint = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath @volatile var continue = true -- cgit v1.2.3 From b11887c086974dbab18b9f53e99a26bbe06e9c86 Mon Sep 17 00:00:00 2001 From: sethah Date: Thu, 31 Mar 2016 13:00:10 -0700 Subject: [SPARK-14264][PYSPARK][ML] Add feature importance for GBTs in pyspark ## What changes were proposed in this pull request? Feature importances are exposed in the python API for GBTs. Other changes: * Update the random forest feature importance documentation to not repeat decision tree docstring and instead place a reference to it. ## How was this patch tested? Python doc tests were updated to validate GBT feature importance. Author: sethah Closes #12056 from sethah/Pyspark_GBT_feature_importance. --- python/pyspark/ml/classification.py | 33 +++++++++++++++++++++++---------- python/pyspark/ml/regression.py | 33 +++++++++++++++++++++++---------- 2 files changed, 46 insertions(+), 20 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 07cafa0993..f5335a3114 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -396,7 +396,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR - Normalize importances for tree to sum to 1. Note: Feature importance for single decision trees can have high variance due to - correlated predictor variables. Consider using a :class:`RandomForestClassifier` + correlated predictor variables. Consider using a :py:class:`RandomForestClassifier` to determine feature importance instead. """ return self._call_java("featureImportances") @@ -500,16 +500,12 @@ class RandomForestClassificationModel(TreeEnsembleModels): """ Estimate of the importance of each feature. 
- This generalizes the idea of "Gini" importance to other losses, - following the explanation of Gini importance from "Random Forests" documentation - by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + Each feature's importance is the average of its importance across all trees in the ensemble + The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + and follows the implementation from scikit-learn. - This feature importance is calculated as follows: - - Average over trees: - - importance(feature j) = sum (over nodes which split on feature j) of the gain, - where gain is scaled by the number of instances passing through node - - Normalize importances for tree to sum to 1. - - Normalize feature importance vector to sum to 1. + .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances` """ return self._call_java("featureImportances") @@ -534,6 +530,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42) >>> model = gbt.fit(td) + >>> model.featureImportances + SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) @@ -613,6 +611,21 @@ class GBTClassificationModel(TreeEnsembleModels): .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def featureImportances(self): + """ + Estimate of the importance of each feature. + + Each feature's importance is the average of its importance across all trees in the ensemble + The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + and follows the implementation from scikit-learn. + + .. seealso:: :py:attr:`DecisionTreeClassificationModel.featureImportances` + """ + return self._call_java("featureImportances") + @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 37648549de..de8a5e4bed 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -533,7 +533,7 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada - Normalize importances for tree to sum to 1. Note: Feature importance for single decision trees can have high variance due to - correlated predictor variables. Consider using a :class:`RandomForestRegressor` + correlated predictor variables. Consider using a :py:class:`RandomForestRegressor` to determine feature importance instead. """ return self._call_java("featureImportances") @@ -626,16 +626,12 @@ class RandomForestRegressionModel(TreeEnsembleModels): """ Estimate of the importance of each feature. - This generalizes the idea of "Gini" importance to other losses, - following the explanation of Gini importance from "Random Forests" documentation - by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. + Each feature's importance is the average of its importance across all trees in the ensemble + The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + (Hastie, Tibshirani, Friedman. 
"The Elements of Statistical Learning, 2nd Edition." 2001.) + and follows the implementation from scikit-learn. - This feature importance is calculated as follows: - - Average over trees: - - importance(feature j) = sum (over nodes which split on feature j) of the gain, - where gain is scaled by the number of instances passing through node - - Normalize importances for tree to sum to 1. - - Normalize feature importance vector to sum to 1. + .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances` """ return self._call_java("featureImportances") @@ -655,6 +651,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> gbt = GBTRegressor(maxIter=5, maxDepth=2, seed=42) >>> model = gbt.fit(df) + >>> model.featureImportances + SparseVector(1, {0: 1.0}) >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) @@ -734,6 +732,21 @@ class GBTRegressionModel(TreeEnsembleModels): .. versionadded:: 1.4.0 """ + @property + @since("2.0.0") + def featureImportances(self): + """ + Estimate of the importance of each feature. + + Each feature's importance is the average of its importance across all trees in the ensemble + The importance vector is normalized to sum to 1. This method is suggested by Hastie et al. + (Hastie, Tibshirani, Friedman. "The Elements of Statistical Learning, 2nd Edition." 2001.) + and follows the implementation from scikit-learn. + + .. seealso:: :py:attr:`DecisionTreeRegressionModel.featureImportances` + """ + return self._call_java("featureImportances") + @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, -- cgit v1.2.3 From a7af6cd2eaf9f6ff491b9e1fabfc9c6f3d0f54bf Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 31 Mar 2016 13:52:59 -0700 Subject: [SPARK-14281][TESTS] Fix java8-tests and simplify their build This patch fixes a compilation / build break in Spark's `java8-tests` and refactors their POM to simplify the build. See individual commit messages for more details. Author: Josh Rosen Closes #12073 from JoshRosen/fix-java8-tests. --- docs/building-spark.md | 8 +-- external/java8-tests/README.md | 8 +-- external/java8-tests/pom.xml | 75 ++++------------------ .../test/java/org/apache/spark/Java8APISuite.java | 10 +-- .../org/apache/spark/streaming/Java8APISuite.java | 18 +++--- .../src/test/resources/log4j.properties | 1 - pom.xml | 31 ++++----- 7 files changed, 46 insertions(+), 105 deletions(-) diff --git a/docs/building-spark.md b/docs/building-spark.md index 1e202acb9e..13aa80496e 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -180,14 +180,14 @@ For help in setting up IntelliJ IDEA or Eclipse for Spark development, and troub Running only Java 8 tests and nothing else. - mvn install -DskipTests -Pjava8-tests + mvn install -DskipTests + mvn -pl :java8-tests_2.11 test or - sbt -Pjava8-tests java8-tests/test + sbt java8-tests/test -Java 8 tests are run when `-Pjava8-tests` profile is enabled, they will run in spite of `-DskipTests`. -For these tests to run your system must have a JDK 8 installation. +Java 8 tests are automatically enabled when a Java 8 JDK is detected. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. 
# Building for PySpark on YARN diff --git a/external/java8-tests/README.md b/external/java8-tests/README.md index dc9e87f2ee..aa87901695 100644 --- a/external/java8-tests/README.md +++ b/external/java8-tests/README.md @@ -8,16 +8,14 @@ to your Java location. The set-up depends a bit on the build system: `-java-home` to the sbt launch script. If a Java 8 JDK is detected sbt will automatically include the Java 8 test project. - `$ JAVA_HOME=/opt/jdk1.8.0/ build/sbt clean "test-only org.apache.spark.Java8APISuite"` + `$ JAVA_HOME=/opt/jdk1.8.0/ build/sbt clean java8-tests/test * For Maven users, - Maven users can also refer to their Java 8 directory using JAVA_HOME. However, Maven will not - automatically detect the presence of a Java 8 JDK, so a special build profile `-Pjava8-tests` - must be used. + Maven users can also refer to their Java 8 directory using JAVA_HOME. `$ JAVA_HOME=/opt/jdk1.8.0/ mvn clean install -DskipTests` - `$ JAVA_HOME=/opt/jdk1.8.0/ mvn test -Pjava8-tests -DwildcardSuites=org.apache.spark.Java8APISuite` + `$ JAVA_HOME=/opt/jdk1.8.0/ mvn -pl :java8-tests_2.11 test` Note that the above command can only be run from project root directory since this module depends on core and the test-jars of core and streaming. This means an install step is diff --git a/external/java8-tests/pom.xml b/external/java8-tests/pom.xml index 0ad9c5303a..f5a06467ee 100644 --- a/external/java8-tests/pom.xml +++ b/external/java8-tests/pom.xml @@ -27,7 +27,7 @@ org.apache.spark java8-tests_2.11 pom - Spark Project Java8 Tests POM + Spark Project Java 8 Tests java8-tests @@ -64,11 +64,6 @@ - - - java8-tests - - @@ -85,76 +80,28 @@ true - - org.apache.maven.plugins - maven-surefire-plugin - - - test - - test - - - - - - - - file:src/test/resources/log4j.properties - - - false - - **/Suite*.java - **/*Suite.java - - - org.apache.maven.plugins maven-compiler-plugin - - - test-compile-first - process-test-resources - - testCompile - - - - true - true true 1.8 - 1.8 1.8 - UTF-8 - 1024m + 1.8 - net.alchim31.maven scala-maven-plugin - - - none - - - scala-compile-first - none - - - scala-test-compile-first - none - - - attach-scaladocs - none - - + + + -source + 1.8 + -target + 1.8 + -Xlint:all,-serial,-path + + diff --git a/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index c0b58e713f..6ac5ca9cf5 100644 --- a/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/external/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -188,7 +188,7 @@ public class Java8APISuite implements Serializable { public void flatMap() { JavaRDD rdd = sc.parallelize(Arrays.asList("Hello World!", "The quick brown fox jumps over the lazy dog.")); - JavaRDD words = rdd.flatMap(x -> Arrays.asList(x.split(" "))); + JavaRDD words = rdd.flatMap(x -> Arrays.asList(x.split(" ")).iterator()); Assert.assertEquals("Hello", words.first()); Assert.assertEquals(11, words.count()); @@ -198,7 +198,7 @@ public class Java8APISuite implements Serializable { for (String word : s.split(" ")) { pairs2.add(new Tuple2<>(word, word)); } - return pairs2; + return pairs2.iterator(); }); Assert.assertEquals(new Tuple2<>("Hello", "Hello"), pairs.first()); @@ -209,7 +209,7 @@ public class Java8APISuite implements Serializable { for (String word : s.split(" ")) { lengths.add((double) word.length()); } - return lengths; + return lengths.iterator(); }); Assert.assertEquals(5.0, doubles.first(), 0.01); @@ -227,7 +227,7 @@ public 
class Java8APISuite implements Serializable { // Regression test for SPARK-668: JavaPairRDD swapped = - pairRDD.flatMapToPair(x -> Collections.singletonList(x.swap())); + pairRDD.flatMapToPair(x -> Collections.singletonList(x.swap()).iterator()); swapped.collect(); // There was never a bug here, but it's worth testing: @@ -242,7 +242,7 @@ public class Java8APISuite implements Serializable { while (iter.hasNext()) { sum += iter.next(); } - return Collections.singletonList(sum); + return Collections.singletonList(sum).iterator(); }); Assert.assertEquals("[3, 7]", partitionSums.collect().toString()); diff --git a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 604d818ef1..67bc64a444 100644 --- a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -29,6 +29,7 @@ import org.junit.Test; import org.apache.spark.Accumulator; import org.apache.spark.HashPartitioner; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.PairFunction; @@ -95,7 +96,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ while (in.hasNext()) { out = out + in.next().toUpperCase(); } - return Lists.newArrayList(out); + return Lists.newArrayList(out).iterator(); }); JavaTestUtils.attachTestOutputStream(mapped); List> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -351,7 +352,8 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ Arrays.asList("a", "t", "h", "l", "e", "t", "i", "c", "s")); JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); - JavaDStream flatMapped = stream.flatMap(s -> Lists.newArrayList(s.split("(?!^)"))); + JavaDStream flatMapped = stream.flatMap( + s -> Lists.newArrayList(s.split("(?!^)")).iterator()); JavaTestUtils.attachTestOutputStream(flatMapped); List> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -360,8 +362,8 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ @Test public void testForeachRDD() { - final Accumulator accumRdd = ssc.sc().accumulator(0); - final Accumulator accumEle = ssc.sc().accumulator(0); + final Accumulator accumRdd = ssc.sparkContext().accumulator(0); + final Accumulator accumEle = ssc.sparkContext().accumulator(0); List> inputData = Arrays.asList( Arrays.asList(1,1,1), Arrays.asList(1,1,1)); @@ -375,7 +377,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ }); // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java - stream.foreachRDD((rdd, time) -> null); + stream.foreachRDD((rdd, time) -> { return; }); JavaTestUtils.runStreams(ssc, 2, 2); @@ -423,7 +425,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ for (String letter : s.split("(?!^)")) { out.add(new Tuple2<>(s.length(), letter)); } - return out; + return out.iterator(); }); JavaTestUtils.attachTestOutputStream(flatMapped); @@ -541,7 +543,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ Tuple2 next = in.next(); out.add(next.swap()); } - return out; + return out.iterator(); }); JavaTestUtils.attachTestOutputStream(reversed); @@ -598,7 +600,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements 
Serializ for (Character s : in._1().toCharArray()) { out.add(new Tuple2<>(in._2(), s.toString())); } - return out; + return out.iterator(); }); JavaTestUtils.attachTestOutputStream(flatMapped); diff --git a/external/java8-tests/src/test/resources/log4j.properties b/external/java8-tests/src/test/resources/log4j.properties index eb3b1999eb..edbecdae92 100644 --- a/external/java8-tests/src/test/resources/log4j.properties +++ b/external/java8-tests/src/test/resources/log4j.properties @@ -25,4 +25,3 @@ log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{ # Ignore messages below warning level from Jetty, because it's a bit verbose log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN diff --git a/pom.xml b/pom.xml index 25d6136421..37606926d7 100644 --- a/pom.xml +++ b/pom.xml @@ -1920,6 +1920,7 @@ ${test.java.home} + file:src/test/resources/log4j.properties test true ${project.build.directory}/tmp @@ -1935,6 +1936,14 @@ false ${test.exclude.tags} + + + test + + test + + + @@ -1959,6 +1968,7 @@ ${test.java.home} + file:src/test/resources/log4j.properties test true ${project.build.directory}/tmp @@ -2343,27 +2353,12 @@ java8-tests - - - - - org.apache.maven.plugins - maven-jar-plugin - - - - test-jar - - - - - - - + + [1.8,) + external/java8-tests - -- cgit v1.2.3 From 8de201baedc8e839e06098c536ba31b3dafd54b5 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Thu, 31 Mar 2016 16:06:44 -0700 Subject: [SPARK-14277][CORE] Upgrade Snappy Java to 1.1.2.4 ## What changes were proposed in this pull request? Upgrade snappy to 1.1.2.4 to improve snappy read/write performance. ## How was this patch tested? Tested by running a job on the cluster and saw 7.5% cpu savings after this change. Author: Sital Kedia Closes #12096 from sitalkedia/snappyRelease. 
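Since the upgraded snappy-java library is exercised through Spark's compression settings, a minimal PySpark sketch is shown below. It is illustrative only: the app name and the toy shuffle job are assumptions, and `spark.io.compression.codec` is set to "snappy" explicitly here rather than relying on whatever codec the deployment defaults to.

```python
from pyspark import SparkConf, SparkContext

# Illustrative sketch, not part of this patch: explicitly select the Snappy
# codec so shuffle and RDD block compression go through snappy-java.
conf = (SparkConf()
        .setAppName("snappy-smoke-test")   # hypothetical app name
        .set("spark.io.compression.codec", "snappy"))
sc = SparkContext(conf=conf)

# A small shuffle-heavy toy job whose map output gets Snappy-compressed.
pairs = sc.parallelize(range(100000), 8).map(lambda x: (x % 100, 1))
print(pairs.reduceByKey(lambda a, b: a + b).count())

sc.stop()
```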
--- dev/deps/spark-deps-hadoop-2.2 | 2 +- dev/deps/spark-deps-hadoop-2.3 | 2 +- dev/deps/spark-deps-hadoop-2.4 | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- pom.xml | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 0c4e43b9c8..115018e7c1 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -168,7 +168,7 @@ servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.1.jar +snappy-java-1.1.2.4.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index a0d62a1c30..246d1147bf 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -159,7 +159,7 @@ servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.1.jar +snappy-java-1.1.2.4.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index cc6e40329c..0e2cdaf0d2 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -160,7 +160,7 @@ servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.1.jar +snappy-java-1.1.2.4.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 5c93db5082..1ed15595be 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -166,7 +166,7 @@ servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.1.jar +snappy-java-1.1.2.4.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 860fd79aad..218631ed6e 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -167,7 +167,7 @@ servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar -snappy-java-1.1.2.1.jar +snappy-java-1.1.2.4.jar spire-macros_2.11-0.7.4.jar spire_2.11-0.7.4.jar stax-api-1.0-2.jar diff --git a/pom.xml b/pom.xml index 37606926d7..be80e6b80c 100644 --- a/pom.xml +++ b/pom.xml @@ -162,7 +162,7 @@ org.scala-lang 1.9.13 2.5.3 - 1.1.2.1 + 1.1.2.4 1.1.2 1.2.0-incubating 1.10 -- cgit v1.2.3 From f0afafdc5dfee80d7e5cd2fc1fa8187def7f262d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 31 Mar 2016 16:40:20 -0700 Subject: [SPARK-14267] [SQL] [PYSPARK] execute multiple Python UDFs within single batch ## What changes were proposed in this pull request? This PR support multiple Python UDFs within single batch, also improve the performance. 
```python >>> from pyspark.sql.types import IntegerType >>> sqlContext.registerFunction("double", lambda x: x * 2, IntegerType()) >>> sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType()) >>> sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").explain(True) == Parsed Logical Plan == 'Project [unresolvedalias('double('add(1, 2)), None),unresolvedalias('add('double(2), 1), None)] +- OneRowRelation$ == Analyzed Logical Plan == double(add(1, 2)): int, add(double(2), 1): int Project [double(add(1, 2))#14,add(double(2), 1)#15] +- Project [double(add(1, 2))#14,add(double(2), 1)#15] +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15] +- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18] +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17] +- OneRowRelation$ == Optimized Logical Plan == Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15] +- EvaluatePython [add(pythonUDF1#17, 1)], [pythonUDF0#18] +- EvaluatePython [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17] +- OneRowRelation$ == Physical Plan == WholeStageCodegen : +- Project [pythonUDF0#16 AS double(add(1, 2))#14,pythonUDF0#18 AS add(double(2), 1)#15] : +- INPUT +- !BatchPythonEvaluation [add(pythonUDF1#17, 1)], [pythonUDF0#16,pythonUDF1#17,pythonUDF0#18] +- !BatchPythonEvaluation [double(add(1, 2)),double(2)], [pythonUDF0#16,pythonUDF1#17] +- Scan OneRowRelation[] ``` ## How was this patch tested? Added new tests. Using the following script to benchmark 1, 2 and 3 UDFs: ``` df = sqlContext.range(1, 1 << 23, 1, 4) double = F.udf(lambda x: x * 2, LongType()) print df.select(double(df.id)).count() print df.select(double(df.id), double(df.id + 1)).count() print df.select(double(df.id), double(df.id + 1), double(df.id + 2)).count() ``` Here are the results: N | Before | After | speed up ---- |------------ | -------------|------ 1 | 22 s | 7 s | 3.1X 2 | 38 s | 13 s | 2.9X 3 | 58 s | 16 s | 3.6X This benchmark ran locally with 4 CPUs. For 3 UDFs, it launched 12 Python processes before this patch, 4 processes after this patch. After this patch, it will use less memory for multiple UDFs than before (less buffering). Author: Davies Liu Closes #12057 from davies/multi_udfs. 
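For reference, a self-contained usage sketch of what this change enables is given below; the context setup and app name are assumptions added for illustration, while the `double`/`add` UDFs and the DataFrame usage mirror the tests and benchmark script quoted above.

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, LongType

# Setup boilerplate (assumed for illustration, not part of the patch).
sc = SparkContext(appName="multi-udf-example")
sqlContext = SQLContext(sc)

# Register two Python UDFs. With this patch, multiple Python UDFs in one
# query can be batched into shared BatchPythonEvaluation operators instead
# of spawning a separate Python worker pipeline per UDF.
sqlContext.registerFunction("double", lambda x: x * 2, IntegerType())
sqlContext.registerFunction("add", lambda x, y: x + y, IntegerType())
print(sqlContext.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect())

# The same batching applies to the DataFrame API.
df = sqlContext.range(0, 10)
double_udf = F.udf(lambda x: x * 2, LongType())
df.select(double_udf(df.id), double_udf(df.id + 1)).show()

sc.stop()
```

Calling `.explain(True)` on either query should show the grouped `EvaluatePython` / `BatchPythonEvaluation` nodes, as in the plan pasted above.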
--- .../org/apache/spark/api/python/PythonRDD.scala | 64 +++++++++++++----- python/pyspark/sql/functions.py | 3 +- python/pyspark/sql/tests.py | 12 +++- python/pyspark/worker.py | 68 ++++++++++++++----- .../spark/sql/execution/SparkStrategies.scala | 4 +- .../execution/python/BatchPythonEvaluation.scala | 78 ++++++++++++++++------ .../sql/execution/python/EvaluatePython.scala | 28 ++++++-- .../sql/execution/python/ExtractPythonUDFs.scala | 77 +++++++++++---------- 8 files changed, 233 insertions(+), 101 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0f579b4ef5..6faa03c12b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -59,7 +59,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false) + val runner = PythonRunner(func, bufferSize, reuse_worker) runner.compute(firstParent.iterator(split, context), split.index, context) } } @@ -78,21 +78,41 @@ private[spark] case class PythonFunction( accumulator: Accumulator[JList[Array[Byte]]]) /** - * A helper class to run Python UDFs in Spark. + * A wrapper for chained Python functions (from bottom to top). + * @param funcs + */ +private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) + +private[spark] object PythonRunner { + def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { + new PythonRunner( + Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0))) + } +} + +/** + * A helper class to run Python mapPartition/UDFs in Spark. + * + * funcs is a list of independent Python functions, each one of them is a list of chained Python + * functions (from bottom to top). */ private[spark] class PythonRunner( - funcs: Seq[PythonFunction], + funcs: Seq[ChainedPythonFunctions], bufferSize: Int, reuse_worker: Boolean, - rowBased: Boolean) + isUDF: Boolean, + argOffsets: Array[Array[Int]]) extends Logging { + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + // All the Python functions should have the same exec, version and envvars. 
- private val envVars = funcs.head.envVars - private val pythonExec = funcs.head.pythonExec - private val pythonVer = funcs.head.pythonVer + private val envVars = funcs.head.funcs.head.envVars + private val pythonExec = funcs.head.funcs.head.pythonExec + private val pythonVer = funcs.head.funcs.head.pythonVer - private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF + // TODO: support accumulator in multiple UDF + private val accumulator = funcs.head.funcs.head.accumulator def compute( inputIterator: Iterator[_], @@ -232,8 +252,8 @@ private[spark] class PythonRunner( @volatile private var _exception: Exception = null - private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet - private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala) + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) setDaemon(true) @@ -284,11 +304,25 @@ private[spark] class PythonRunner( } dataOut.flush() // Serialized command: - dataOut.writeInt(if (rowBased) 1 else 0) - dataOut.writeInt(funcs.length) - funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) + if (isUDF) { + dataOut.writeInt(1) + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + } else { + dataOut.writeInt(0) + val command = funcs.head.funcs.head.command + dataOut.writeInt(command.length) + dataOut.write(command) } // Data values PythonRDD.writeIteratorToStream(inputIterator, dataOut) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3211834226..3b20ba5177 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1649,8 +1649,7 @@ def sort_array(col, asc=True): # ---------------------------- User Defined Function ---------------------------------- def _wrap_function(sc, func, returnType): - ser = AutoBatchedSerializer(PickleSerializer()) - command = (func, returnType, ser) + command = (func, returnType) pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command) return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec, sc.pythonVer, broadcast_vars, sc._javaAccumulator) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 84947560e7..536ef55251 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -305,7 +305,7 @@ class SQLTests(ReusedPySparkTestCase): [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) - def test_chained_python_udf(self): + def test_chained_udf(self): self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.sqlCtx.sql("SELECT double(1)").collect() self.assertEqual(row[0], 2) @@ -314,6 +314,16 @@ class SQLTests(ReusedPySparkTestCase): [row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect() self.assertEqual(row[0], 6) + def test_multiple_udfs(self): + self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect() + self.assertEqual(tuple(row), (2, 4)) + [row] = self.sqlCtx.sql("SELECT double(double(1)), 
double(double(2) + 2)").collect() + self.assertEqual(tuple(row), (4, 12)) + self.sqlCtx.registerFunction("add", lambda x, y: x + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT double(add(1, 2)), add(double(2), 1)").collect() + self.assertEqual(tuple(row), (6, 5)) + def test_udf_with_array_type(self): d = [Row(l=list(range(3)), d={"key": list(range(5))})] rdd = self.sc.parallelize(d) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 0f05fe31aa..cf47ab8f96 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -29,7 +29,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, BatchedSerializer from pyspark import shuffle pickleSer = PickleSerializer() @@ -59,7 +59,54 @@ def read_command(serializer, file): def chain(f, g): """chain two function together """ - return lambda x: g(f(x)) + return lambda *a: g(f(*a)) + + +def wrap_udf(f, return_type): + if return_type.needConversion(): + toInternal = return_type.toInternal + return lambda *a: toInternal(f(*a)) + else: + return lambda *a: f(*a) + + +def read_single_udf(pickleSer, infile): + num_arg = read_int(infile) + arg_offsets = [read_int(infile) for i in range(num_arg)] + row_func = None + for i in range(read_int(infile)): + f, return_type = read_command(pickleSer, infile) + if row_func is None: + row_func = f + else: + row_func = chain(row_func, f) + # the last returnType will be the return type of UDF + return arg_offsets, wrap_udf(row_func, return_type) + + +def read_udfs(pickleSer, infile): + num_udfs = read_int(infile) + if num_udfs == 1: + # fast path for single UDF + _, udf = read_single_udf(pickleSer, infile) + mapper = lambda a: udf(*a) + else: + udfs = {} + call_udf = [] + for i in range(num_udfs): + arg_offsets, udf = read_single_udf(pickleSer, infile) + udfs['f%d' % i] = udf + args = ["a[%d]" % o for o in arg_offsets] + call_udf.append("f%d(%s)" % (i, ", ".join(args))) + # Create function like this: + # lambda a: (f0(a0), f1(a1, a2), f2(a3)) + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + mapper = eval(mapper_str, udfs) + + func = lambda _, it: map(mapper, it) + ser = BatchedSerializer(PickleSerializer(), 100) + # profiling is not supported for UDF + return func, None, ser, ser def main(infile, outfile): @@ -107,21 +154,10 @@ def main(infile, outfile): _broadcastRegistry.pop(bid) _accumulatorRegistry.clear() - row_based = read_int(infile) - num_commands = read_int(infile) - if row_based: - profiler = None # profiling is not supported for UDF - row_func = None - for i in range(num_commands): - f, returnType, deserializer = read_command(pickleSer, infile) - if row_func is None: - row_func = f - else: - row_func = chain(row_func, f) - serializer = deserializer - func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it) + is_sql_udf = read_int(infile) + if is_sql_udf: + func, profiler, deserializer, serializer = read_udfs(pickleSer, infile) else: - assert num_commands == 1 func, profiler, deserializer, serializer = read_command(pickleSer, infile) init_time = time.time() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 
7841ff01f9..7a2e2b7382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -426,8 +426,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.RepartitionByExpression(expressions, child, nPartitions) => exchange.ShuffleExchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil - case e @ python.EvaluatePython(udf, child, _) => - python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil + case e @ python.EvaluatePython(udfs, child, _) => + python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index a76009e7df..c9ab40a0a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -18,16 +18,17 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.TaskContext -import org.apache.spark.api.python.{PythonFunction, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonRunner} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.types.{DataType, StructField, StructType} /** @@ -40,20 +41,20 @@ import org.apache.spark.sql.types.{StructField, StructType} * we drain the queue to find the original input row. Note that if the Python process is way too * slow, this could lead to the queue growing unbounded and eventually run out of memory. */ -case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) +case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan) extends SparkPlan { def children: Seq[SparkPlan] = child :: Nil - private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = { + private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { udf.children match { case Seq(u: PythonUDF) => - val (fs, children) = collectFunctions(u) - (fs ++ Seq(udf.func), children) + val (chained, children) = collectFunctions(u) + (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) case children => // There should not be any other UDFs, or the children can't be evaluated directly. assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty)) - (Seq(udf.func), udf.children) + (ChainedPythonFunctions(Seq(udf.func)), udf.children) } } @@ -69,19 +70,47 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: // combine input with output from Python. 
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]() - val (pyFuncs, children) = collectFunctions(udf) - - val pickle = new Pickler - val currentRow = newMutableProjection(children, child.output)() - val fields = children.map(_.dataType) - val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray) + val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + + // flatten all the arguments + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + val projection = newMutableProjection(allInputs, child.output)() + val schema = StructType(dataTypes.map(dt => StructField("", dt))) + val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) + // enable memo iff we serialize the row with schema (schema and class should be memorized) + val pickle = new Pickler(needConversion) // Input iterator to Python: input rows are grouped so we send them in batches to Python. // For each row, add it to the queue. val inputIterator = iter.grouped(100).map { inputRows => - val toBePickled = inputRows.map { row => - queue.add(row) - EvaluatePython.toJava(currentRow(row), schema) + val toBePickled = inputRows.map { inputRow => + queue.add(inputRow) + val row = projection(inputRow) + if (needConversion) { + EvaluatePython.toJava(row, schema) + } else { + // fast path for these types that does not need conversion in Python + val fields = new Array[Any](row.numFields) + var i = 0 + while (i < row.numFields) { + val dt = dataTypes(i) + fields(i) = EvaluatePython.toJava(row.get(i, dt), dt) + i += 1 + } + fields + } }.toArray pickle.dumps(toBePickled) } @@ -89,19 +118,30 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val context = TaskContext.get() // Output iterator for results from Python. 
- val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true) + val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true, argOffsets) .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler - val row = new GenericMutableRow(1) + val mutableRow = new GenericMutableRow(1) val joined = new JoinedRow + val resultType = if (udfs.length == 1) { + udfs.head.dataType + } else { + StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) + } val resultProj = UnsafeProjection.create(output, output) outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) + val row = if (udfs.length == 1) { + // fast path for single UDF + mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow + } else { + EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + } resultProj(joined(queue.poll(), row)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index da28ec4f53..f3d1c44b25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -36,24 +36,28 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** - * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. + * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple. */ case class EvaluatePython( - udf: PythonUDF, + udfs: Seq[PythonUDF], child: LogicalPlan, - resultAttribute: AttributeReference) + resultAttribute: Seq[AttributeReference]) extends logical.UnaryNode { - def output: Seq[Attribute] = child.output :+ resultAttribute + def output: Seq[Attribute] = child.output ++ resultAttribute // References should not include the produced attribute. - override def references: AttributeSet = udf.references + override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references)) } object EvaluatePython { - def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = - new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) + def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = { + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => + AttributeReference(s"pythonUDF$i", u.dataType)() + } + new EvaluatePython(udfs, child, resultAttrs) + } def takeAndServe(df: DataFrame, n: Int): Int = { registerPicklers() @@ -66,6 +70,16 @@ object EvaluatePython { } } + def needConversionInPython(dt: DataType): Boolean = dt match { + case DateType | TimestampType => true + case _: StructType => true + case _: UserDefinedType[_] => true + case ArrayType(elementType, _) => needConversionInPython(elementType) + case MapType(keyType, valueType, _) => + needConversionInPython(keyType) || needConversionInPython(valueType) + case _ => false + } + /** * Helper for converting from Catalyst type to java type suitable for Pyrolite. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index c486ce18e8..0934cd135d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.execution.python -import org.apache.spark.sql.catalyst.expressions.Expression +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -47,10 +49,9 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { } } - private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = { - expr.collect { - case udf: PythonUDF if canEvaluateInPython(udf) => udf - } + private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = expr match { + case udf: PythonUDF if canEvaluateInPython(udf) => Seq(udf) + case e => e.children.flatMap(collectEvaluatableUDF) } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { @@ -59,45 +60,43 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { case plan: LogicalPlan if plan.resolved => // Extract any PythonUDFs from the current operator. - val udfs = plan.expressions.flatMap(collectEvaluatableUDF) + val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved) if (udfs.isEmpty) { // If there aren't any, we are done. plan } else { - // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) - // If there is more than one, we will add another evaluation operator in a subsequent pass. - udfs.find(_.resolved) match { - case Some(udf) => - var evaluation: EvaluatePython = null - - // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartesian - // product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluatePython(udf, child) - evaluation - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - child - } - } - - assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") - - // Trim away the new UDF value if it was only used for filtering or something. - logical.Project( - plan.output, - plan.transformExpressions { - case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute - }.withNewChildren(newChildren)) - - case None => - // If there is no Python UDF that is resolved, skip this round. - plan + val attributeMap = mutable.HashMap[PythonUDF, Expression]() + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Pick the UDF we are going to evaluate + val validUdfs = udfs.filter { case udf => + // Check to make sure that the UDF can be evaluated with only the input of this child. 
+ udf.references.subsetOf(child.outputSet) + } + if (validUdfs.nonEmpty) { + val evaluation = EvaluatePython(validUdfs, child) + attributeMap ++= validUdfs.zip(evaluation.resultAttribute) + evaluation + } else { + child + } } + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + udfs.filterNot(attributeMap.contains).foreach { udf => + if (udf.references.subsetOf(plan.inputSet)) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.") + } + } + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + plan.output, + plan.transformExpressions { + case p: PythonUDF if attributeMap.contains(p) => attributeMap(p) + }.withNewChildren(newChildren)) } } } -- cgit v1.2.3 From 96941b12f8b465df21423275f3cd3ade579b4fa1 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Thu, 31 Mar 2016 20:17:52 -0700 Subject: [SPARK-14242][CORE][NETWORK] avoid copy in compositeBuffer for frame decoder ## What changes were proposed in this pull request? In this patch, we set the initial `maxNumComponents` to `Integer.MAX_VALUE` instead of the default size ( which is 16) when allocating `compositeBuffer` in `TransportFrameDecoder` because `compositeBuffer` will introduce too many memory copies underlying if `compositeBuffer` is with default `maxNumComponents` when the frame size is large (which result in many transport messages). For details, please refer to [SPARK-14242](https://issues.apache.org/jira/browse/SPARK-14242). ## How was this patch tested? spark unit tests and manual tests. For manual tests, we can reproduce the performance issue with following code: `sc.parallelize(Array(1,2,3),3).mapPartitions(a=>Array(new Array[Double](1024 * 1024 * 50)).iterator).reduce((a,b)=> a).length` It's easy to see the performance gain, both from the running time and CPU usage. Author: Zhang, Liye Closes #12038 from liyezhang556520/spark-14242. --- .../main/java/org/apache/spark/network/util/TransportFrameDecoder.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index bd1830e6ab..fcec7dfd0c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -140,7 +140,7 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { } // Otherwise, create a composite buffer. - CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(); + CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE); while (remaining > 0) { ByteBuf next = nextBufferForFrame(remaining); remaining -= next.readableBytes(); -- cgit v1.2.3 From 1b070637fa03ab4966f76427b15e433050eaa956 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 31 Mar 2016 23:46:08 -0700 Subject: [SPARK-14295][SPARK-14274][SQL] Implements buildReader() for LibSVM ## What changes were proposed in this pull request? This PR implements `FileFormat.buildReader()` for the LibSVM data source. 
Besides that, a new interface method `prepareRead()` is added to `FileFormat`: ```scala def prepareRead( sqlContext: SQLContext, options: Map[String, String], files: Seq[FileStatus]): Map[String, String] = options ``` After migrating from `buildInternalScan()` to `buildReader()`, we lost the opportunity to collect necessary global information, since `buildReader()` works in a per-partition manner. For example, LibSVM needs to infer the total number of features if the `numFeatures` data source option is not set. Any necessary collected global information should be returned using the data source options map. By default, this method just returns the original options untouched. An alternative approach is to absorb `inferSchema()` into `prepareRead()`, since schema inference is also some kind of global information gathering. However, this approach wasn't chosen because schema inference is optional, while `prepareRead()` must be called whenever a `HadoopFsRelation` based data source relation is instantiated. One unaddressed problem is that, when `numFeatures` is absent, now the input data will be scanned twice. The `buildInternalScan()` code path doesn't need to do this because it caches the raw parsed RDD in memory before computing the total number of features. However, with `FileScanRDD`, the raw parsed RDD is created in a different way (e.g. partitioning) from the final RDD. ## How was this patch tested? Tested using existing test suites. Author: Cheng Lian Closes #12088 from liancheng/spark-14295-libsvm-build-reader. --- .../spark/ml/source/libsvm/LibSVMRelation.scala | 87 +++++++++++++++++++++- .../org/apache/spark/mllib/util/MLUtils.scala | 73 ++++++++++-------- .../sql/execution/datasources/DataSource.scala | 5 +- .../execution/datasources/FileSourceStrategy.scala | 1 + .../org/apache/spark/sql/sources/interfaces.scala | 9 +++ 5 files changed, 141 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 13a13f0a7e..2e9b6be9a2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.source.libsvm import java.io.IOException +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} @@ -26,12 +27,16 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.annotation.Since import org.apache.spark.broadcast.Broadcast -import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow} +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFileLinesReader, PartitionedFile} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import 
org.apache.spark.util.SerializableConfiguration @@ -110,13 +115,16 @@ class DefaultSource extends FileFormat with DataSourceRegister { @Since("1.6.0") override def shortName(): String = "libsvm" + override def toString: String = "LibSVM" + private def verifySchema(dataSchema: StructType): Unit = { if (dataSchema.size != 2 || (!dataSchema(0).dataType.sameType(DataTypes.DoubleType) || !dataSchema(1).dataType.sameType(new VectorUDT()))) { - throw new IOException(s"Illegal schema for libsvm data, schema=${dataSchema}") + throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema") } } + override def inferSchema( sqlContext: SQLContext, options: Map[String, String], @@ -127,6 +135,32 @@ class DefaultSource extends FileFormat with DataSourceRegister { StructField("features", new VectorUDT(), nullable = false) :: Nil)) } + override def prepareRead( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Map[String, String] = { + def computeNumFeatures(): Int = { + val dataFiles = files.filterNot(_.getPath.getName startsWith "_") + val path = if (dataFiles.length == 1) { + dataFiles.head.getPath.toUri.toString + } else if (dataFiles.isEmpty) { + throw new IOException("No input path specified for libsvm data") + } else { + throw new IOException("Multiple input paths are not supported for libsvm data.") + } + + val sc = sqlContext.sparkContext + val parsed = MLUtils.parseLibSVMFile(sc, path, sc.defaultParallelism) + MLUtils.computeNumFeatures(parsed) + } + + val numFeatures = options.get("numFeatures").filter(_.toInt > 0).getOrElse { + computeNumFeatures() + } + + new CaseInsensitiveMap(options + ("numFeatures" -> numFeatures.toString)) + } + override def prepareWrite( sqlContext: SQLContext, job: Job, @@ -158,7 +192,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { verifySchema(dataSchema) val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_") - val path = if (dataFiles.length == 1) dataFiles(0).getPath.toUri.toString + val path = if (dataFiles.length == 1) dataFiles.head.getPath.toUri.toString else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data") else throw new IOException("Multiple input paths are not supported for libsvm data.") @@ -176,4 +210,51 @@ class DefaultSource extends FileFormat with DataSourceRegister { externalRows.map(converter.toRow) } } + + override def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { + val numFeatures = options("numFeatures").toInt + assert(numFeatures > 0) + + val sparse = options.getOrElse("vectorType", "sparse") == "sparse" + + val broadcastedConf = sqlContext.sparkContext.broadcast( + new SerializableConfiguration(new Configuration(sqlContext.sparkContext.hadoopConfiguration)) + ) + + (file: PartitionedFile) => { + val points = + new HadoopFileLinesReader(file, broadcastedConf.value.value) + .map(_.toString.trim) + .filterNot(line => line.isEmpty || line.startsWith("#")) + .map { line => + val (label, indices, values) = MLUtils.parseLibSVMRecord(line) + LabeledPoint(label, Vectors.sparse(numFeatures, indices, values)) + } + + val converter = RowEncoder(requiredSchema) + + val unsafeRowIterator = points.map { pt => + val features = if (sparse) pt.features.toSparse else pt.features.toDense + converter.toRow(Row(pt.label, features)) + } + + def toAttribute(f: 
StructField): AttributeReference = + AttributeReference(f.name, f.dataType, f.nullable, f.metadata)() + + // Appends partition values + val fullOutput = (requiredSchema ++ partitionSchema).map(toAttribute) + val joinedRow = new JoinedRow() + val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput) + + unsafeRowIterator.map { dataRow => + appendPartitionColumns(joinedRow(dataRow, file.partitionValues)) + } + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index c3b1d5cdd7..4b9d77949f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -67,42 +67,14 @@ object MLUtils { path: String, numFeatures: Int, minPartitions: Int): RDD[LabeledPoint] = { - val parsed = sc.textFile(path, minPartitions) - .map(_.trim) - .filter(line => !(line.isEmpty || line.startsWith("#"))) - .map { line => - val items = line.split(' ') - val label = items.head.toDouble - val (indices, values) = items.tail.filter(_.nonEmpty).map { item => - val indexAndValue = item.split(':') - val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. - val value = indexAndValue(1).toDouble - (index, value) - }.unzip - - // check if indices are one-based and in ascending order - var previous = -1 - var i = 0 - val indicesLength = indices.length - while (i < indicesLength) { - val current = indices(i) - require(current > previous, s"indices should be one-based and in ascending order;" - + " found current=$current, previous=$previous; line=\"$line\"") - previous = current - i += 1 - } - - (label, indices.toArray, values.toArray) - } + val parsed = parseLibSVMFile(sc, path, minPartitions) // Determine number of features. val d = if (numFeatures > 0) { numFeatures } else { parsed.persist(StorageLevel.MEMORY_ONLY) - parsed.map { case (label, indices, values) => - indices.lastOption.getOrElse(0) - }.reduce(math.max) + 1 + computeNumFeatures(parsed) } parsed.map { case (label, indices, values) => @@ -110,6 +82,47 @@ object MLUtils { } } + private[spark] def computeNumFeatures(rdd: RDD[(Double, Array[Int], Array[Double])]): Int = { + rdd.map { case (label, indices, values) => + indices.lastOption.getOrElse(0) + }.reduce(math.max) + 1 + } + + private[spark] def parseLibSVMFile( + sc: SparkContext, + path: String, + minPartitions: Int): RDD[(Double, Array[Int], Array[Double])] = { + sc.textFile(path, minPartitions) + .map(_.trim) + .filter(line => !(line.isEmpty || line.startsWith("#"))) + .map(parseLibSVMRecord) + } + + private[spark] def parseLibSVMRecord(line: String): (Double, Array[Int], Array[Double]) = { + val items = line.split(' ') + val label = items.head.toDouble + val (indices, values) = items.tail.filter(_.nonEmpty).map { item => + val indexAndValue = item.split(':') + val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. 
+ val value = indexAndValue(1).toDouble + (index, value) + }.unzip + + // check if indices are one-based and in ascending order + var previous = -1 + var i = 0 + val indicesLength = indices.length + while (i < indicesLength) { + val current = indices(i) + require(current > previous, s"indices should be one-based and in ascending order;" + + " found current=$current, previous=$previous; line=\"$line\"") + previous = current + i += 1 + } + + (label, indices, values) + } + /** * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], with the default number of * partitions. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index c66921f485..1850810270 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -299,6 +299,9 @@ case class DataSource( "It must be specified manually") } + val enrichedOptions = + format.prepareRead(sqlContext, caseInsensitiveOptions, fileCatalog.allFiles()) + HadoopFsRelation( sqlContext, fileCatalog, @@ -306,7 +309,7 @@ case class DataSource( dataSchema = dataSchema.asNullable, bucketSpec = bucketSpec, format, - options) + enrichedOptions) case _ => throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 554298772a..a143ac6aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -59,6 +59,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { if (files.fileFormat.toString == "TestFileFormat" || files.fileFormat.isInstanceOf[parquet.DefaultSource] || files.fileFormat.toString == "ORC" || + files.fileFormat.toString == "LibSVM" || files.fileFormat.isInstanceOf[csv.DefaultSource] || files.fileFormat.isInstanceOf[text.DefaultSource] || files.fileFormat.isInstanceOf[json.DefaultSource]) && diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6b95a3d25b..e8834d052c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -438,6 +438,15 @@ trait FileFormat { options: Map[String, String], files: Seq[FileStatus]): Option[StructType] + /** + * Prepares a read job and returns a potentially updated data source option [[Map]]. This method + * can be useful for collecting necessary global information for scanning input data. + */ + def prepareRead( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Map[String, String] = options + /** * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can * be put here. 
For example, user defined output committer can be configured here -- cgit v1.2.3 From 26867ebc67edab97376c5d8fee76df294359e461 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Thu, 31 Mar 2016 23:48:36 -0700 Subject: [SPARK-11262][ML] Unit test for gradient, loss layers, memory management for multilayer perceptron 1.Implement LossFunction trait and implement squared error and cross entropy loss with it 2.Implement unit test for gradient and loss 3.Implement InPlace trait and in-place layer evaluation 4.Refactor interface for ActivationFunction 5.Update of Layer and LayerModel interfaces 6.Fix random weights assignment 7.Implement memory allocation by MLP model instead of individual layers These features decreased the memory usage and increased flexibility of internal API. Author: Alexander Ulanov Author: avulanov Closes #9229 from avulanov/mlp-refactoring. --- .../main/scala/org/apache/spark/ml/ann/Layer.scala | 662 ++++++++++----------- .../org/apache/spark/ml/ann/LossFunction.scala | 124 ++++ .../MultilayerPerceptronClassifier.scala | 82 ++- .../JavaMultilayerPerceptronClassifierSuite.java | 2 +- .../scala/org/apache/spark/ml/ann/ANNSuite.scala | 9 +- .../org/apache/spark/ml/ann/GradientSuite.scala | 76 +++ .../MultilayerPerceptronClassifierSuite.scala | 26 +- project/MimaExcludes.scala | 5 + python/pyspark/ml/classification.py | 2 +- 9 files changed, 601 insertions(+), 387 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 2cd94fa8f5..a5b84116e6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -17,9 +17,9 @@ package org.apache.spark.ml.ann -import breeze.linalg.{*, axpy => Baxpy, sum => Bsum, DenseMatrix => BDM, DenseVector => BDV, - Vector => BV} -import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} +import java.util.Random + +import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.optimization._ @@ -32,20 +32,46 @@ import org.apache.spark.util.random.XORShiftRandom * */ private[ann] trait Layer extends Serializable { + /** - * Returns the instance of the layer based on weights provided - * @param weights vector with layer weights - * @param position position of weights in the vector - * @return the layer model + * Number of weights that is used to allocate memory for the weights vector + */ + val weightSize: Int + + /** + * Returns the output size given the input size (not counting the stack size). + * Output size is used to allocate memory for the output. + * + * @param inputSize input size + * @return output size */ - def getInstance(weights: Vector, position: Int): LayerModel + def getOutputSize(inputSize: Int): Int + /** + * If true, the memory is not allocated for the output of this layer. + * The memory allocated to the previous layer is used to write the output of this layer. + * Developer can set this to true if computing delta of a previous layer + * does not involve its output, so the current layer can write there. + * This also mean that both layers have the same number of outputs. + */ + val inPlace: Boolean + + /** + * Returns the instance of the layer based on weights provided. 
+ * Size of weights must be equal to weightSize + * + * @param initialWeights vector with layer weights + * @return the layer model + */ + def createModel(initialWeights: BDV[Double]): LayerModel /** * Returns the instance of the layer with random generated weights - * @param seed seed + * + * @param weights vector for weights initialization, must be equal to weightSize + * @param random random number generator * @return the layer model */ - def getInstance(seed: Long): LayerModel + def initModel(weights: BDV[Double], random: Random): LayerModel } /** @@ -54,92 +80,102 @@ private[ann] trait Layer extends Serializable { * Can return weights in Vector format. */ private[ann] trait LayerModel extends Serializable { - /** - * number of weights - */ - val size: Int + val weights: BDV[Double] /** * Evaluates the data (process the data through the layer) + * Output is allocated based on the size provided by the + * LayerModel implementation and the stack (batch) size + * Developer is responsible for checking the size of output + * when writing to it + * * @param data data - * @return processed data + * @param output output (modified in place) */ - def eval(data: BDM[Double]): BDM[Double] + def eval(data: BDM[Double], output: BDM[Double]): Unit /** * Computes the delta for back propagation - * @param nextDelta delta of the next layer - * @param input input data - * @return delta + * Delta is allocated based on the size provided by the + * LayerModel implementation and the stack (batch) size + * Developer is responsible for checking the size of + * prevDelta when writing to it + * + * @param delta delta of this layer + * @param output output of this layer + * @param prevDelta the previous delta (modified in place) */ - def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] + def computePrevDelta(delta: BDM[Double], output: BDM[Double], prevDelta: BDM[Double]): Unit /** * Computes the gradient + * cumGrad is a wrapper on the part of the weight vector + * size of cumGrad is based on weightSize provided by + * implementation of LayerModel + * * @param delta delta for this layer * @param input input data - * @return gradient + * @param cumGrad cumulative gradient (modified in place) */ - def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] - - /** - * Returns weights for the layer in a single vector - * @return layer weights - */ - def weights(): Vector + def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit } /** * Layer properties of affine transformations, that is y=A*x+b + * * @param numIn number of inputs * @param numOut number of outputs */ private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer { - override def getInstance(weights: Vector, position: Int): LayerModel = { - AffineLayerModel(this, weights, position) - } + override val weightSize = numIn * numOut + numOut - override def getInstance(seed: Long = 11L): LayerModel = { - AffineLayerModel(this, seed) - } + override def getOutputSize(inputSize: Int): Int = numOut + + override val inPlace = false + + override def createModel(weights: BDV[Double]): LayerModel = new AffineLayerModel(weights, this) + + override def initModel(weights: BDV[Double], random: Random): LayerModel = + AffineLayerModel(this, weights, random) } /** - * Model of Affine layer y=A*x+b - * @param w weights (matrix A) - * @param b bias (vector b) + * Model of Affine layer + * + * @param weights weights + * @param layer layer properties */ -private[ann] class AffineLayerModel private(w: BDM[Double], b: 
BDV[Double]) extends LayerModel { - val size = w.size + b.length - val gwb = new Array[Double](size) - private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb) - private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size) - private var z: BDM[Double] = null - private var d: BDM[Double] = null +private[ann] class AffineLayerModel private[ann] ( + val weights: BDV[Double], + val layer: AffineLayer) extends LayerModel { + val w = new BDM[Double](layer.numOut, layer.numIn, weights.data, weights.offset) + val b = + new BDV[Double](weights.data, weights.offset + (layer.numOut * layer.numIn), 1, layer.numOut) + private var ones: BDV[Double] = null - override def eval(data: BDM[Double]): BDM[Double] = { - if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols) - z(::, *) := b - BreezeUtil.dgemm(1.0, w, data, 1.0, z) - z + override def eval(data: BDM[Double], output: BDM[Double]): Unit = { + output(::, *) := b + BreezeUtil.dgemm(1.0, w, data, 1.0, output) } - override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { - if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols) - BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d) - d + override def computePrevDelta( + delta: BDM[Double], + output: BDM[Double], + prevDelta: BDM[Double]): Unit = { + BreezeUtil.dgemm(1.0, w.t, delta, 0.0, prevDelta) } - override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = { - BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw) + override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = { + // compute gradient of weights + val cumGradientOfWeights = new BDM[Double](w.rows, w.cols, cumGrad.data, cumGrad.offset) + BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 1.0, cumGradientOfWeights) if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols) - BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb) - gwb + // compute gradient of bias + val cumGradientOfBias = new BDV[Double](cumGrad.data, cumGrad.offset + w.size, 1, b.length) + BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 1.0, cumGradientOfBias) } - - override def weights(): Vector = AffineLayerModel.roll(w, b) } /** @@ -149,73 +185,40 @@ private[ann] object AffineLayerModel { /** * Creates a model of Affine layer + * * @param layer layer properties - * @param weights vector with weights - * @param position position of weights in the vector - * @return model of Affine layer - */ - def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = { - val (w, b) = unroll(weights, position, layer.numIn, layer.numOut) - new AffineLayerModel(w, b) - } - - /** - * Creates a model of Affine layer - * @param layer layer properties - * @param seed seed + * @param weights vector for weights initialization + * @param random random number generator * @return model of Affine layer */ - def apply(layer: AffineLayer, seed: Long): AffineLayerModel = { - val (w, b) = randomWeights(layer.numIn, layer.numOut, seed) - new AffineLayerModel(w, b) - } - - /** - * Unrolls the weights from the vector - * @param weights vector with weights - * @param position position of weights for this layer - * @param numIn number of layer inputs - * @param numOut number of layer outputs - * @return matrix A and vector b - */ - def unroll( - weights: Vector, - position: Int, - numIn: Int, - numOut: Int): (BDM[Double], BDV[Double]) = { - val weightsCopy = weights.toArray - // TODO: the array is not copied to BDMs, 
make sure this is OK! - val a = new BDM[Double](numOut, numIn, weightsCopy, position) - val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut) - (a, b) - } - - /** - * Roll the layer weights into a vector - * @param a matrix A - * @param b vector b - * @return vector of weights - */ - def roll(a: BDM[Double], b: BDV[Double]): Vector = { - val result = new Array[Double](a.size + b.length) - // TODO: make sure that we need to copy! - System.arraycopy(a.toArray, 0, result, 0, a.size) - System.arraycopy(b.toArray, 0, result, a.size, b.length) - Vectors.dense(result) + def apply(layer: AffineLayer, weights: BDV[Double], random: Random): AffineLayerModel = { + randomWeights(layer.numIn, layer.numOut, weights, random) + new AffineLayerModel(weights, layer) } /** - * Generate random weights for the layer - * @param numIn number of inputs + * Initialize weights randomly in the interval + * Uses [Bottou-88] heuristic [-a/sqrt(in); a/sqrt(in)] + * where a is chosen in a such way that the weight variance corresponds + * to the points to the maximal curvature of the activation function + * (which is approximately 2.38 for a standard sigmoid) + * + * @param numIn number of inputs * @param numOut number of outputs - * @param seed seed - * @return (matrix A, vector b) + * @param weights vector for weights initialization + * @param random random number generator */ - def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = { - val rand: XORShiftRandom = new XORShiftRandom(seed) - val weights = BDM.fill[Double](numOut, numIn) { (rand.nextDouble * 4.8 - 2.4) / numIn } - val bias = BDV.fill[Double](numOut) { (rand.nextDouble * 4.8 - 2.4) / numIn } - (weights, bias) + def randomWeights( + numIn: Int, + numOut: Int, + weights: BDV[Double], + random: Random): Unit = { + var i = 0 + val sqrtIn = math.sqrt(numIn) + while (i < weights.length) { + weights(i) = (random.nextDouble * 4.8 - 2.4) / sqrtIn + i += 1 + } } } @@ -226,44 +229,21 @@ private[ann] trait ActivationFunction extends Serializable { /** * Implements a function - * @param x input data - * @param y output data */ - def eval(x: BDM[Double], y: BDM[Double]): Unit + def eval: Double => Double /** * Implements a derivative of a function (needed for the back propagation) - * @param x input data - * @param y output data */ - def derivative(x: BDM[Double], y: BDM[Double]): Unit - - /** - * Implements a cross entropy error of a function. - * Needed if the functional layer that contains this function is the output layer - * of the network. 
- * @param target target output - * @param output computed output - * @param result intermediate result - * @return cross-entropy - */ - def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double - - /** - * Implements a mean squared error of a function - * @param target target output - * @param output computed output - * @param result intermediate result - * @return mean squared error - */ - def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double + def derivative: Double => Double } /** - * Implements in-place application of functions + * Implements in-place application of functions in the arrays */ -private[ann] object ActivationFunction { +private[ann] object ApplyInPlace { + // TODO: use Breeze UFunc def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = { var i = 0 while (i < x.rows) { @@ -276,6 +256,7 @@ private[ann] object ActivationFunction { } } + // TODO: use Breeze UFunc def apply( x1: BDM[Double], x2: BDM[Double], @@ -293,180 +274,87 @@ private[ann] object ActivationFunction { } } -/** - * Implements SoftMax activation function - */ -private[ann] class SoftmaxFunction extends ActivationFunction { - override def eval(x: BDM[Double], y: BDM[Double]): Unit = { - var j = 0 - // find max value to make sure later that exponent is computable - while (j < x.cols) { - var i = 0 - var max = Double.MinValue - while (i < x.rows) { - if (x(i, j) > max) { - max = x(i, j) - } - i += 1 - } - var sum = 0.0 - i = 0 - while (i < x.rows) { - val res = Math.exp(x(i, j) - max) - y(i, j) = res - sum += res - i += 1 - } - i = 0 - while (i < x.rows) { - y(i, j) /= sum - i += 1 - } - j += 1 - } - } - - override def crossEntropy( - output: BDM[Double], - target: BDM[Double], - result: BDM[Double]): Double = { - def m(o: Double, t: Double): Double = o - t - ActivationFunction(output, target, result, m) - -Bsum( target :* Blog(output)) / output.cols - } - - override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { - def sd(z: Double): Double = (1 - z) * z - ActivationFunction(x, y, sd) - } - - override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { - throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.") - } -} - /** * Implements Sigmoid activation function */ private[ann] class SigmoidFunction extends ActivationFunction { - override def eval(x: BDM[Double], y: BDM[Double]): Unit = { - def s(z: Double): Double = Bsigmoid(z) - ActivationFunction(x, y, s) - } - - override def crossEntropy( - output: BDM[Double], - target: BDM[Double], - result: BDM[Double]): Double = { - def m(o: Double, t: Double): Double = o - t - ActivationFunction(output, target, result, m) - -Bsum(target :* Blog(output)) / output.cols - } - override def derivative(x: BDM[Double], y: BDM[Double]): Unit = { - def sd(z: Double): Double = (1 - z) * z - ActivationFunction(x, y, sd) - } + override def eval: (Double) => Double = x => 1.0 / (1 + math.exp(-x)) - override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = { - // TODO: make it readable - def m(o: Double, t: Double): Double = (o - t) - ActivationFunction(output, target, result, m) - val e = Bsum(result :* result) / 2 / output.cols - def m2(x: Double, o: Double) = x * (o - o * o) - ActivationFunction(result, output, result, m2) - e - } + override def derivative: (Double) => Double = z => (1 - z) * z } /** * Functional layer properties, y = f(x) + * * @param activationFunction activation 
function */ private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer { - override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L) - override def getInstance(seed: Long): LayerModel = - FunctionalLayerModel(this) + override val weightSize = 0 + + override def getOutputSize(inputSize: Int): Int = inputSize + + override val inPlace = true + + override def createModel(weights: BDV[Double]): LayerModel = new FunctionalLayerModel(this) + + override def initModel(weights: BDV[Double], random: Random): LayerModel = + createModel(weights) } /** * Functional layer model. Holds no weights. - * @param activationFunction activation function + * + * @param layer functiona layer */ -private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction) +private[ann] class FunctionalLayerModel private[ann] (val layer: FunctionalLayer) extends LayerModel { - val size = 0 - // matrices for in-place computations - // outputs - private var f: BDM[Double] = null - // delta - private var d: BDM[Double] = null - // matrix for error computation - private var e: BDM[Double] = null - // delta gradient - private lazy val dg = new Array[Double](0) - override def eval(data: BDM[Double]): BDM[Double] = { - if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols) - activationFunction.eval(data, f) - f - } + // empty weights + val weights = new BDV[Double](0) - override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = { - if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols) - activationFunction.derivative(input, d) - d :*= nextDelta - d + override def eval(data: BDM[Double], output: BDM[Double]): Unit = { + ApplyInPlace(data, output, layer.activationFunction.eval) } - override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg - - override def weights(): Vector = Vectors.dense(new Array[Double](0)) - - def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { - if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) - val error = activationFunction.crossEntropy(output, target, e) - (e, error) + override def computePrevDelta( + nextDelta: BDM[Double], + input: BDM[Double], + delta: BDM[Double]): Unit = { + ApplyInPlace(input, delta, layer.activationFunction.derivative) + delta :*= nextDelta } - def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { - if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols) - val error = activationFunction.squared(output, target, e) - (e, error) - } - - def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = { - // TODO: allow user pick error - activationFunction match { - case sigmoid: SigmoidFunction => squared(output, target) - case softmax: SoftmaxFunction => crossEntropy(output, target) - } - } -} - -/** - * Fabric of functional layer models - */ -private[ann] object FunctionalLayerModel { - def apply(layer: FunctionalLayer): FunctionalLayerModel = - new FunctionalLayerModel(layer.activationFunction) + override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {} } /** * Trait for the artificial neural network (ANN) topology properties */ -private[ann] trait Topology extends Serializable{ - def getInstance(weights: Vector): TopologyModel - def getInstance(seed: Long): TopologyModel +private[ann] trait Topology extends Serializable { 
+ def model(weights: Vector): TopologyModel + def model(seed: Long): TopologyModel } /** * Trait for ANN topology model */ -private[ann] trait TopologyModel extends Serializable{ +private[ann] trait TopologyModel extends Serializable { + + val weights: Vector + /** + * Array of layers + */ + val layers: Array[Layer] + + /** + * Array of layer models + */ + val layerModels: Array[LayerModel] /** * Forward propagation + * * @param data input data * @return array of outputs for each of the layers */ @@ -474,6 +362,7 @@ private[ann] trait TopologyModel extends Serializable{ /** * Prediction of the model + * * @param data input data * @return prediction */ @@ -481,6 +370,7 @@ private[ann] trait TopologyModel extends Serializable{ /** * Computes gradient for the network + * * @param data input data * @param target target output * @param cumGradient cumulative gradient @@ -489,22 +379,17 @@ private[ann] trait TopologyModel extends Serializable{ */ def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector, blockSize: Int): Double - - /** - * Returns the weights of the ANN - * @return weights - */ - def weights(): Vector } /** * Feed forward ANN + * * @param layers */ private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology { - override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights) + override def model(weights: Vector): TopologyModel = FeedForwardModel(this, weights) - override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed) + override def model(seed: Long): TopologyModel = FeedForwardModel(this, seed) } /** @@ -513,6 +398,7 @@ private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends private[ml] object FeedForwardTopology { /** * Creates a feed forward topology from the array of layers + * * @param layers array of layers * @return feed forward topology */ @@ -522,18 +408,26 @@ private[ml] object FeedForwardTopology { /** * Creates a multi-layer perceptron + * * @param layerSizes sizes of layers including input and output size - * @param softmax whether to use SoftMax or Sigmoid function for an output layer. + * @param softmaxOnTop wether to use SoftMax or Sigmoid function for an output layer. * Softmax is default * @return multilayer perceptron topology */ - def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = { + def multiLayerPerceptron( + layerSizes: Array[Int], + softmaxOnTop: Boolean = true): FeedForwardTopology = { val layers = new Array[Layer]((layerSizes.length - 1) * 2) - for(i <- 0 until layerSizes.length - 1) { + for (i <- 0 until layerSizes.length - 1) { layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1)) layers(i * 2 + 1) = - if (softmax && i == layerSizes.length - 2) { - new FunctionalLayer(new SoftmaxFunction()) + if (i == layerSizes.length - 2) { + if (softmaxOnTop) { + new SoftmaxLayerWithCrossEntropyLoss() + } else { + // TODO: squared error is more natural but converges slower + new SigmoidLayerWithSquaredError() + } } else { new FunctionalLayer(new SigmoidFunction()) } @@ -545,17 +439,45 @@ private[ml] object FeedForwardTopology { /** * Model of Feed Forward Neural Network. * Implements forward, gradient computation and can return weights in vector format. 
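With the Topology/TopologyModel split above, a network is described once and then instantiated either from a random seed or from an existing weight vector. A hedged usage sketch; these APIs are private[ann]/private[ml], so the fragment assumes it lives inside the org.apache.spark.ml.ann package, as the test suites touched by this patch do:

```
// Hypothetical usage from inside org.apache.spark.ml.ann (the APIs are package-private).
val topology = FeedForwardTopology.multiLayerPerceptron(
  layerSizes = Array(2, 5, 2), softmaxOnTop = true)

val randomModel  = topology.model(seed = 11L)     // weights initialized from the seed
val flatWeights  = randomModel.weights            // flat weight vector of the whole network
val rebuiltModel = topology.model(flatWeights)    // identical model rebuilt from those weights
```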
- * @param layerModels models of layers - * @param topology topology of the network + * + * @param weights network weights + * @param topology network topology */ private[ml] class FeedForwardModel private( - val layerModels: Array[LayerModel], + val weights: Vector, val topology: FeedForwardTopology) extends TopologyModel { + + val layers = topology.layers + val layerModels = new Array[LayerModel](layers.length) + private var offset = 0 + for (i <- 0 until layers.length) { + layerModels(i) = layers(i).createModel( + new BDV[Double](weights.toArray, offset, 1, layers(i).weightSize)) + offset += layers(i).weightSize + } + private var outputs: Array[BDM[Double]] = null + private var deltas: Array[BDM[Double]] = null + override def forward(data: BDM[Double]): Array[BDM[Double]] = { - val outputs = new Array[BDM[Double]](layerModels.length) - outputs(0) = layerModels(0).eval(data) + // Initialize output arrays for all layers. Special treatment for InPlace + val currentBatchSize = data.cols + // TODO: allocate outputs as one big array and then create BDMs from it + if (outputs == null || outputs(0).cols != currentBatchSize) { + outputs = new Array[BDM[Double]](layers.length) + var inputSize = data.rows + for (i <- 0 until layers.length) { + if (layers(i).inPlace) { + outputs(i) = outputs(i - 1) + } else { + val outputSize = layers(i).getOutputSize(inputSize) + outputs(i) = new BDM[Double](outputSize, currentBatchSize) + inputSize = outputSize + } + } + } + layerModels(0).eval(data, outputs(0)) for (i <- 1 until layerModels.length) { - outputs(i) = layerModels(i).eval(outputs(i-1)) + layerModels(i).eval(outputs(i - 1), outputs(i)) } outputs } @@ -566,54 +488,36 @@ private[ml] class FeedForwardModel private( cumGradient: Vector, realBatchSize: Int): Double = { val outputs = forward(data) - val deltas = new Array[BDM[Double]](layerModels.length) + val currentBatchSize = data.cols + // TODO: allocate deltas as one big array and then create BDMs from it + if (deltas == null || deltas(0).cols != currentBatchSize) { + deltas = new Array[BDM[Double]](layerModels.length) + var inputSize = data.rows + for (i <- 0 until layerModels.length - 1) { + val outputSize = layers(i).getOutputSize(inputSize) + deltas(i) = new BDM[Double](outputSize, currentBatchSize) + inputSize = outputSize + } + } val L = layerModels.length - 1 - val (newE, newError) = layerModels.last match { - case flm: FunctionalLayerModel => flm.error(outputs.last, target) + // TODO: explain why delta of top layer is null (because it might contain loss+layer) + val loss = layerModels.last match { + case levelWithError: LossFunction => levelWithError.loss(outputs.last, target, deltas(L - 1)) case _ => - throw new UnsupportedOperationException("Non-functional layer not supported at the top") + throw new UnsupportedOperationException("Top layer is required to have objective.") } - deltas(L) = new BDM[Double](0, 0) - deltas(L - 1) = newE for (i <- (L - 2) to (0, -1)) { - deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1)) - } - val grads = new Array[Array[Double]](layerModels.length) - for (i <- 0 until layerModels.length) { - val input = if (i==0) data else outputs(i - 1) - grads(i) = layerModels(i).grad(deltas(i), input) + layerModels(i + 1).computePrevDelta(deltas(i + 1), outputs(i + 1), deltas(i)) } - // update cumGradient val cumGradientArray = cumGradient.toArray var offset = 0 - // TODO: extract roll - for (i <- 0 until grads.length) { - val gradArray = grads(i) - var k = 0 - while (k < gradArray.length) { - 
cumGradientArray(offset + k) += gradArray(k) - k += 1 - } - offset += gradArray.length - } - newError - } - - // TODO: do we really need to copy the weights? they should be read-only - override def weights(): Vector = { - // TODO: extract roll - var size = 0 - for (i <- 0 until layerModels.length) { - size += layerModels(i).size - } - val array = new Array[Double](size) - var offset = 0 for (i <- 0 until layerModels.length) { - val layerWeights = layerModels(i).weights().toArray - System.arraycopy(layerWeights, 0, array, offset, layerWeights.length) - offset += layerWeights.length + val input = if (i == 0) data else outputs(i - 1) + layerModels(i).grad(deltas(i), input, + new BDV[Double](cumGradientArray, offset, 1, layers(i).weightSize)) + offset += layers(i).weightSize } - Vectors.dense(array) + loss } override def predict(data: Vector): Vector = { @@ -630,23 +534,19 @@ private[ann] object FeedForwardModel { /** * Creates a model from a topology and weights + * * @param topology topology * @param weights weights * @return model */ def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { - val layers = topology.layers - val layerModels = new Array[LayerModel](layers.length) - var offset = 0 - for (i <- 0 until layers.length) { - layerModels(i) = layers(i).getInstance(weights, offset) - offset += layerModels(i).size - } - new FeedForwardModel(layerModels, topology) + // TODO: check that weights size is equal to sum of layers sizes + new FeedForwardModel(weights, topology) } /** * Creates a model given a topology and seed + * * @param topology topology * @param seed seed for generating the weights * @return model @@ -654,17 +554,25 @@ private[ann] object FeedForwardModel { def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { val layers = topology.layers val layerModels = new Array[LayerModel](layers.length) + var totalSize = 0 + for (i <- 0 until topology.layers.length) { + totalSize += topology.layers(i).weightSize + } + val weights = BDV.zeros[Double](totalSize) var offset = 0 - for(i <- 0 until layers.length) { - layerModels(i) = layers(i).getInstance(seed) - offset += layerModels(i).size + val random = new XORShiftRandom(seed) + for (i <- 0 until layers.length) { + layerModels(i) = layers(i). + initModel(new BDV[Double](weights.data, offset, 1, layers(i).weightSize), random) + offset += layers(i).weightSize } - new FeedForwardModel(layerModels, topology) + new FeedForwardModel(Vectors.fromBreeze(weights), topology) } } /** * Neural network gradient. Does nothing but calling Model's gradient + * * @param topology topology * @param dataStacker data stacker */ @@ -682,7 +590,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext weights: Vector, cumGradient: Vector): Double = { val (input, target, realBatchSize) = dataStacker.unstack(data) - val model = topology.getInstance(weights) + val model = topology.model(weights) model.computeGradient(input, target, cumGradient, realBatchSize) } } @@ -692,6 +600,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext * through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks * or matrices of inputs and outputs and then stack them in one vector. * This can be used for further batch computations after unstacking. 
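The stacking described above amounts to packing stackSize (input, output) pairs into one flat array so they can later be viewed as two matrices, one column per example, and processed with a single BLAS call per layer. A minimal standalone sketch of that packing, assuming Breeze; the real DataStacker additionally goes through MLlib vectors and RDDs, so the layout here is only illustrative:

```
import breeze.linalg.{DenseMatrix => BDM}

object StackingSketch {
  /** Packs the pairs into one array: all inputs first, then all outputs, column-major. */
  def stack(pairs: Seq[(Array[Double], Array[Double])],
            inputSize: Int,
            outputSize: Int): (BDM[Double], BDM[Double]) = {
    val n = pairs.length
    val flat = new Array[Double]((inputSize + outputSize) * n)
    pairs.zipWithIndex.foreach { case ((in, out), i) =>
      System.arraycopy(in, 0, flat, i * inputSize, inputSize)
      System.arraycopy(out, 0, flat, n * inputSize + i * outputSize, outputSize)
    }
    // "Unstack": both matrices are no-copy views over `flat`, one example per column.
    val inputs  = new BDM[Double](inputSize, n, flat)
    val outputs = new BDM[Double](outputSize, n, flat, n * inputSize)
    (inputs, outputs)
  }

  def main(args: Array[String]): Unit = {
    val (x, y) = stack(
      Seq((Array(1.0, 2.0), Array(0.0)), (Array(3.0, 4.0), Array(1.0))),
      inputSize = 2, outputSize = 1)
    println(x)  // 2 x 2 input matrix
    println(y)  // 1 x 2 output matrix
  }
}
```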
+ * * @param stackSize stack size * @param inputSize size of the input vectors * @param outputSize size of the output vectors @@ -701,6 +610,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) /** * Stacks the data + * * @param data RDD of vector pairs * @return RDD of double (always zero) and vector that contains the stacked vectors */ @@ -733,6 +643,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) /** * Unstack the stacked vectors into matrices for batch operations + * * @param data stacked vector * @return pair of matrices holding input and output data and the real stack size */ @@ -765,6 +676,7 @@ private[ann] class ANNUpdater extends Updater { /** * MLlib-style trainer class that trains a network given the data and topology + * * @param topology topology of ANN * @param inputSize input size * @param outputSize output size @@ -774,36 +686,50 @@ private[ml] class FeedForwardTrainer( val inputSize: Int, val outputSize: Int) extends Serializable { - // TODO: what if we need to pass random seed? - private var _weights = topology.getInstance(11L).weights() + private var _seed = this.getClass.getName.hashCode.toLong + private var _weights: Vector = null private var _stackSize = 128 private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize) private var _gradient: Gradient = new ANNGradient(topology, dataStacker) private var _updater: Updater = new ANNUpdater() private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100) + /** + * Returns seed + */ + def getSeed: Long = _seed + + /** + * Sets seed + */ + def setSeed(value: Long): this.type = { + _seed = value + this + } + /** * Returns weights - * @return weights */ def getWeights: Vector = _weights /** * Sets weights + * * @param value weights * @return trainer */ - def setWeights(value: Vector): FeedForwardTrainer = { + def setWeights(value: Vector): this.type = { _weights = value this } /** * Sets the stack size + * * @param value stack size * @return trainer */ - def setStackSize(value: Int): FeedForwardTrainer = { + def setStackSize(value: Int): this.type = { _stackSize = value dataStacker = new DataStacker(value, inputSize, outputSize) this @@ -811,6 +737,7 @@ private[ml] class FeedForwardTrainer( /** * Sets the SGD optimizer + * * @return SGD optimizer */ def SGDOptimizer: GradientDescent = { @@ -821,6 +748,7 @@ private[ml] class FeedForwardTrainer( /** * Sets the LBFGS optimizer + * * @return LBGS optimizer */ def LBFGSOptimizer: LBFGS = { @@ -831,10 +759,11 @@ private[ml] class FeedForwardTrainer( /** * Sets the updater + * * @param value updater * @return trainer */ - def setUpdater(value: Updater): FeedForwardTrainer = { + def setUpdater(value: Updater): this.type = { _updater = value updateUpdater(value) this @@ -842,10 +771,11 @@ private[ml] class FeedForwardTrainer( /** * Sets the gradient + * * @param value gradient * @return trainer */ - def setGradient(value: Gradient): FeedForwardTrainer = { + def setGradient(value: Gradient): this.type = { _gradient = value updateGradient(value) this @@ -871,12 +801,20 @@ private[ml] class FeedForwardTrainer( /** * Trains the ANN + * * @param data RDD of input and output vector pairs * @return model */ def train(data: RDD[(Vector, Vector)]): TopologyModel = { - val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights) - topology.getInstance(newWeights) + val w = if (getWeights == null) { + // TODO: will make a copy if vector is a subvector of BDV (see 
Vectors code) + topology.model(_seed).weights + } else { + getWeights + } + // TODO: deprecate standard optimizer because it needs Vector + val newWeights = optimizer.optimize(dataStacker.stack(data), w) + topology.model(newWeights) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala new file mode 100644 index 0000000000..32d78e9b22 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import java.util.Random + +import breeze.linalg.{sum => Bsum, DenseMatrix => BDM, DenseVector => BDV} +import breeze.numerics.{log => brzlog} + +/** + * Trait for loss function + */ +private[ann] trait LossFunction { + /** + * Returns the value of loss function. + * Computes loss based on target and output. + * Writes delta (error) to delta in place. + * Delta is allocated based on the outputSize + * of model implementation. 
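The LossFunction contract asks the top layer model to return the scalar loss and to write the output-layer delta in place. A minimal sketch of the plain squared-error case, assuming Breeze; the SigmoidLayerModelWithSquaredError defined further down additionally folds the sigmoid derivative into delta after computing the error:

```
import breeze.linalg.{sum, DenseMatrix => BDM}

object SquaredErrorSketch {
  /** Writes delta = output - target in place and returns the batch-averaged squared error. */
  def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
    delta := output
    delta -= target
    sum(delta :* delta) / 2.0 / output.cols
  }
}
```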
+ * + * @param output actual output + * @param target target output + * @param delta delta (updated in place) + * @return loss + */ + def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double +} + +private[ann] class SigmoidLayerWithSquaredError extends Layer { + override val weightSize = 0 + override val inPlace = true + + override def getOutputSize(inputSize: Int): Int = inputSize + override def createModel(weights: BDV[Double]): LayerModel = + new SigmoidLayerModelWithSquaredError() + override def initModel(weights: BDV[Double], random: Random): LayerModel = + new SigmoidLayerModelWithSquaredError() +} + +private[ann] class SigmoidLayerModelWithSquaredError + extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction { + override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { + ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) + val error = Bsum(delta :* delta) / 2 / output.cols + ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o)) + error + } +} + +private[ann] class SoftmaxLayerWithCrossEntropyLoss extends Layer { + override val weightSize = 0 + override val inPlace = true + + override def getOutputSize(inputSize: Int): Int = inputSize + override def createModel(weights: BDV[Double]): LayerModel = + new SoftmaxLayerModelWithCrossEntropyLoss() + override def initModel(weights: BDV[Double], random: Random): LayerModel = + new SoftmaxLayerModelWithCrossEntropyLoss() +} + +private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with LossFunction { + + // loss layer models do not have weights + val weights = new BDV[Double](0) + + override def eval(data: BDM[Double], output: BDM[Double]): Unit = { + var j = 0 + // find max value to make sure later that exponent is computable + while (j < data.cols) { + var i = 0 + var max = Double.MinValue + while (i < data.rows) { + if (data(i, j) > max) { + max = data(i, j) + } + i += 1 + } + var sum = 0.0 + i = 0 + while (i < data.rows) { + val res = math.exp(data(i, j) - max) + output(i, j) = res + sum += res + i += 1 + } + i = 0 + while (i < data.rows) { + output(i, j) /= sum + i += 1 + } + j += 1 + } + } + override def computePrevDelta( + nextDelta: BDM[Double], + input: BDM[Double], + delta: BDM[Double]): Unit = { + /* loss layer model computes delta in loss function */ + } + + override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = { + /* loss layer model does not have weights */ + } + + override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { + ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) + -Bsum( target :* brzlog(output)) / output.cols + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 7ce3ec68da..79bb2a8855 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -24,8 +24,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} -import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators} 
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol} +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint @@ -33,11 +33,12 @@ import org.apache.spark.sql.DataFrame /** Params for Multilayer Perceptron. */ private[ml] trait MultilayerPerceptronParams extends PredictorParams - with HasSeed with HasMaxIter with HasTol { + with HasSeed with HasMaxIter with HasTol with HasStepSize { /** * Layer sizes including input size and output size. * Default: Array(1, 1) - * @group param + * + * @group param */ final val layers: IntArrayParam = new IntArrayParam(this, "layers", "Sizes of layers from input layer to output layer" + @@ -55,7 +56,8 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams * a partition then it is adjusted to the size of this data. * Recommended size is between 10 and 1000. * Default: 128 - * @group expertParam + * + * @group expertParam */ final val blockSize: IntParam = new IntParam(this, "blockSize", "Block size for stacking input data in matrices. Data is stacked within partitions." + @@ -66,7 +68,33 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams /** @group getParam */ final def getBlockSize: Int = $(blockSize) - setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128) + /** + * Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. + * l-bfgs is the default one. + * + * @group expertParam + */ + final val solver: Param[String] = new Param[String](this, "solver", + " Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. " + + " l-bfgs is the default one.", + ParamValidators.inArray[String](Array("gd", "l-bfgs"))) + + /** @group getParam */ + final def getOptimizer: String = $(solver) + + /** + * Model weights. Can be returned either after training or after explicit setting + * + * @group expertParam + */ + final val weights: Param[Vector] = new Param[Vector](this, "weights", + " Sets the weights of the model ") + + /** @group getParam */ + final def getWeights: Vector = $(weights) + + + setDefault(maxIter -> 100, tol -> 1e-4, blockSize -> 128, solver -> "l-bfgs", stepSize -> 0.03) } /** Label to vector converter. */ @@ -105,6 +133,7 @@ private object LabelConverter { * Each layer has sigmoid activation function, output layer has softmax. * Number of inputs has to be equal to the size of feature vectors. * Number of outputs has to be equal to the total number of labels. + * */ @Since("1.5.0") @Experimental @@ -127,7 +156,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( /** * Set the maximum number of iterations. * Default is 100. - * @group setParam + * + * @group setParam */ @Since("1.5.0") def setMaxIter(value: Int): this.type = set(maxIter, value) @@ -136,18 +166,28 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( * Set the convergence tolerance of iterations. * Smaller value will lead to higher accuracy with the cost of more iterations. * Default is 1E-4. - * @group setParam + * + * @group setParam */ @Since("1.5.0") def setTol(value: Double): this.type = set(tol, value) /** - * Set the seed for weights initialization. - * @group setParam + * Set the seed for weights initialization if weights are not set + * + * @group setParam */ @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) + /** + * Sets the model weights. 
+ * + * @group expertParam + */ + @Since("2.0.0") + def setWeights(value: Vector): this.type = set(weights, value) + @Since("1.5.0") override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra) @@ -165,11 +205,18 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( val lpData = extractLabeledPoints(dataset) val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels)) val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true) - val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) - FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter)) - FeedForwardTrainer.setStackSize($(blockSize)) - val mlpModel = FeedForwardTrainer.train(data) - new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights()) + val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last) + if (isDefined(weights)) { + trainer.setWeights($(weights)) + } else { + trainer.setSeed($(seed)) + } + trainer.LBFGSOptimizer + .setConvergenceTol($(tol)) + .setNumIterations($(maxIter)) + trainer.setStackSize($(blockSize)) + val mlpModel = trainer.train(data) + new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights) } } @@ -185,7 +232,8 @@ object MultilayerPerceptronClassifier * :: Experimental :: * Classification model based on the Multilayer Perceptron. * Each layer has sigmoid activation function, output layer has softmax. - * @param uid uid + * + * @param uid uid * @param layers array of layer sizes including input and output layers * @param weights vector of initial weights for the model that consists of the weights of layers * @return prediction model @@ -202,7 +250,7 @@ class MultilayerPerceptronClassificationModel private[ml] ( @Since("1.6.0") override val numFeatures: Int = layers.head - private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights) + private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).model(weights) /** * Returns layers in a Java List. 
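With the new weights param, training can be warm-started from a previously fitted model instead of a random seed. A hedged sketch mirroring the restart test added later in this patch; `train` is an assumed DataFrame with "features" and "label" columns:

```
// Hypothetical warm-start usage of the new setWeights param.
import org.apache.spark.ml.classification.MultilayerPerceptronClassifier

val mlp = new MultilayerPerceptronClassifier()
  .setLayers(Array[Int](2, 5, 2))
  .setBlockSize(1)
  .setSeed(123L)
  .setMaxIter(100)

val firstModel = mlp.fit(train)

// Continue training from the learned weights instead of re-initializing from the seed.
val continued = mlp.setWeights(firstModel.weights).setMaxIter(50).fit(train)
```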
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index d499d363f1..bc955f3cf6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -63,7 +63,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable { MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier() .setLayers(new int[] {2, 5, 2}) .setBlockSize(1) - .setSeed(11L) + .setSeed(123L) .setMaxIter(100); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); Dataset result = model.transform(dataFrame); diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala index 1292e57d7c..dc91fc5f9e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala @@ -42,7 +42,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) - val initialWeights = FeedForwardModel(topology, 23124).weights() + val initialWeights = FeedForwardModel(topology, 23124).weights val trainer = new FeedForwardTrainer(topology, 2, 1) trainer.setWeights(initialWeights) trainer.LBFGSOptimizer.setNumIterations(20) @@ -76,10 +76,11 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext { val dataSample = rddData.first() val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false) - val initialWeights = FeedForwardModel(topology, 23124).weights() + val initialWeights = FeedForwardModel(topology, 23124).weights val trainer = new FeedForwardTrainer(topology, 2, 2) - trainer.SGDOptimizer.setNumIterations(2000) - trainer.setWeights(initialWeights) + // TODO: add a test for SGD + trainer.LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(20) + trainer.setWeights(initialWeights).setStackSize(1) val model = trainer.train(rddData) val predictionAndLabels = rddData.map { case (input, label) => (model.predict(input), label) diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala new file mode 100644 index 0000000000..04cc426c40 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.ann + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class GradientSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("Gradient computation against numerical differentiation") { + val input = new BDM[Double](3, 1, Array(1.0, 1.0, 1.0)) + // output must contain zeros and one 1 for SoftMax + val target = new BDM[Double](2, 1, Array(0.0, 1.0)) + val topology = FeedForwardTopology.multiLayerPerceptron(Array(3, 4, 2), softmaxOnTop = false) + val layersWithErrors = Seq( + new SigmoidLayerWithSquaredError(), + new SoftmaxLayerWithCrossEntropyLoss() + ) + // check all layers that provide loss computation + // 1) compute loss and gradient given the model and initial weights + // 2) modify weights with small number epsilon (per dimension i) + // 3) compute new loss + // 4) ((newLoss - loss) / epsilon) should be close to the i-th component of the gradient + for (layerWithError <- layersWithErrors) { + topology.layers(topology.layers.length - 1) = layerWithError + val model = topology.model(seed = 12L) + val weights = model.weights.toArray + val numWeights = weights.size + val gradient = Vectors.dense(Array.fill[Double](numWeights)(0.0)) + val loss = model.computeGradient(input, target, gradient, 1) + val eps = 1e-4 + var i = 0 + val tol = 1e-4 + while (i < numWeights) { + val originalValue = weights(i) + weights(i) += eps + val newModel = topology.model(Vectors.dense(weights)) + val newLoss = computeLoss(input, target, newModel) + val derivativeEstimate = (newLoss - loss) / eps + assert(math.abs(gradient(i) - derivativeEstimate) < tol, "Layer failed gradient check: " + + layerWithError.getClass) + weights(i) = originalValue + i += 1 + } + } + } + + private def computeLoss(input: BDM[Double], target: BDM[Double], model: TopologyModel): Double = { + val outputs = model.forward(input) + model.layerModels.last match { + case layerWithLoss: LossFunction => + layerWithLoss.loss(outputs.last, target, new BDM[Double](target.rows, target.cols)) + case _ => + throw new UnsupportedOperationException("Top layer is required to have loss." 
+ + " Failed layer:" + model.layerModels.last.getClass) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 53c7a559e3..43781385db 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -65,7 +65,7 @@ class MultilayerPerceptronClassifierSuite val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) - .setSeed(11L) + .setSeed(123L) .setMaxIter(100) val model = trainer.fit(dataset) val result = model.transform(dataset) @@ -75,7 +75,29 @@ class MultilayerPerceptronClassifierSuite } } - // TODO: implement a more rigorous test + test("Test setWeights by training restart") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(12L) + .setMaxIter(1) + .setTol(1e-6) + val initialWeights = trainer.fit(dataFrame).weights + trainer.setWeights(initialWeights.copy) + val weights1 = trainer.fit(dataFrame).weights + trainer.setWeights(initialWeights.copy) + val weights2 = trainer.fit(dataFrame).weights + assert(weights1 ~== weights2 absTol 10e-5, + "Training should produce the same weights given equal initial weights and number of steps") + } + test("3 class classification with 2 hidden layers") { val nPoints = 1000 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 94621d7fa3..ff11775412 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -592,6 +592,11 @@ object MimaExcludes { ) ++ Seq( // [SPARK-14205][SQL] remove trait Queryable ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset") + ) ++ Seq( + // [SPARK-11262][ML] Unit test for gradient, loss layers, memory management + // for multilayer perceptron. + // This class is marked as `private`. + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction") ) case v if v.startsWith("1.6") => Seq( diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index f5335a3114..067009559b 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -788,7 +788,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, ... (1.0, Vectors.dense([0.0, 1.0])), ... (1.0, Vectors.dense([1.0, 0.0])), ... (0.0, Vectors.dense([1.0, 1.0]))], ["label", "features"]) - >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=11) + >>> mlp = MultilayerPerceptronClassifier(maxIter=100, layers=[2, 5, 2], blockSize=1, seed=123) >>> model = mlp.fit(df) >>> model.layers [2, 5, 2] -- cgit v1.2.3 From 22249afb4a932a82ff1f7a3befea9fda5a60a3f4 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 31 Mar 2016 23:49:58 -0700 Subject: [SPARK-14303][ML][SPARKR] Define and use KMeansWrapper for SparkR::kmeans ## What changes were proposed in this pull request? Define and use ```KMeansWrapper``` for ```SparkR::kmeans```. It's only the code refactor for the original ```KMeans``` wrapper. ## How was this patch tested? 
Existing tests. cc mengxr Author: Yanbo Liang Closes #12039 from yanboliang/spark-14059. --- R/pkg/R/mllib.R | 91 +++++++++++++++------- .../org/apache/spark/ml/r/KMeansWrapper.scala | 85 ++++++++++++++++++++ .../org/apache/spark/ml/r/SparkRWrappers.scala | 52 +------------ 3 files changed, 148 insertions(+), 80 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 33654d5216..f3152cc232 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @export setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj")) +#' @title S4 class that represents a KMeansModel +#' @param jobj a Java object reference to the backing Scala KMeansModel +#' @export +setClass("KMeansModel", representation(jobj = "jobj")) + #' Fits a generalized linear model #' #' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. @@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"), colnames(coefficients) <- c("Estimate") rownames(coefficients) <- unlist(features) return(list(coefficients = coefficients)) - } else if (modelName == "KMeansModel") { - modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getKMeansModelSize", object@model) - cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getKMeansCluster", object@model, "classes") - k <- unlist(modelSize)[1] - size <- unlist(modelSize)[-1] - coefficients <- t(matrix(coefficients, ncol = k)) - colnames(coefficients) <- unlist(features) - rownames(coefficients) <- 1:k - return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster))) } else { stop(paste("Unsupported model", modelName, sep = " ")) } @@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"), #' @examples #' \dontrun{ #' model <- kmeans(x, centers = 2, algorithm="random") -#'} +#' } setMethod("kmeans", signature(x = "DataFrame"), function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) { columnNames <- as.array(colnames(x)) algorithm <- match.arg(algorithm) - model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf, - algorithm, iter.max, centers, columnNames) - return(new("PipelineModel", model = model)) + jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf, + centers, iter.max, algorithm, columnNames) + return(new("KMeansModel", jobj = jobj)) }) -#' Get fitted result from a model +#' Get fitted result from a k-means model #' -#' Get fitted result from a model, similarly to R's fitted(). +#' Get fitted result from a k-means model, similarly to R's fitted(). #' -#' @param object A fitted MLlib model +#' @param object A fitted k-means model #' @return DataFrame containing fitted values #' @rdname fitted #' @export @@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"), #' fitted.model <- fitted(model) #' showDF(fitted.model) #'} -setMethod("fitted", signature(object = "PipelineModel"), +setMethod("fitted", signature(object = "KMeansModel"), function(object, method = c("centers", "classes"), ...) 
{ - modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", object@model) + method <- match.arg(method) + return(dataFrame(callJMethod(object@jobj, "fitted", method))) + }) - if (modelName == "KMeansModel") { - method <- match.arg(method) - fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getKMeansCluster", object@model, method) - return(dataFrame(fittedResult)) - } else { - stop(paste("Unsupported model", modelName, sep = " ")) - } +#' Get the summary of a k-means model +#' +#' Returns the summary of a k-means model produced by kmeans(), +#' similarly to R's summary(). +#' +#' @param object a fitted k-means model +#' @return the model's coefficients, size and cluster +#' @rdname summary +#' @export +#' @examples +#' \dontrun{ +#' model <- kmeans(trainingData, 2) +#' summary(model) +#' } +setMethod("summary", signature(object = "KMeansModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "features") + coefficients <- callJMethod(jobj, "coefficients") + cluster <- callJMethod(jobj, "cluster") + k <- callJMethod(jobj, "k") + size <- callJMethod(jobj, "size") + coefficients <- t(matrix(coefficients, ncol = k)) + colnames(coefficients) <- unlist(features) + rownames(coefficients) <- 1:k + return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster))) + }) + +#' Make predictions from a k-means model +#' +#' Make predictions from a model produced by kmeans(). +#' +#' @param object A fitted k-means model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted labels in a column named "prediction" +#' @rdname predict +#' @export +#' @examples +#' \dontrun{ +#' model <- kmeans(trainingData, 2) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#' } +setMethod("predict", signature(object = "KMeansModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) #' Fit a Bernoulli naive Bayes model diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala new file mode 100644 index 0000000000..d3a0df4063 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.r + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.clustering.{KMeans, KMeansModel} +import org.apache.spark.ml.feature.VectorAssembler +import org.apache.spark.sql.DataFrame + +private[r] class KMeansWrapper private ( + pipeline: PipelineModel) { + + private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel] + + lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray) + + private lazy val attrs = AttributeGroup.fromStructField( + kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol)) + + lazy val features: Array[String] = attrs.attributes.get.map(_.name.get) + + lazy val k: Int = kMeansModel.getK + + lazy val size: Array[Int] = kMeansModel.summary.size + + lazy val cluster: DataFrame = kMeansModel.summary.cluster + + def fitted(method: String): DataFrame = { + if (method == "centers") { + kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol) + } else if (method == "classes") { + kMeansModel.summary.cluster + } else { + throw new UnsupportedOperationException( + s"Method (centers or classes) required but $method found.") + } + } + + def transform(dataset: DataFrame): DataFrame = { + pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol) + } + +} + +private[r] object KMeansWrapper { + + def fit( + data: DataFrame, + k: Double, + maxIter: Double, + initMode: String, + columns: Array[String]): KMeansWrapper = { + + val assembler = new VectorAssembler() + .setInputCols(columns) + .setOutputCol("features") + + val kMeans = new KMeans() + .setK(k.toInt) + .setMaxIter(maxIter.toInt) + .setInitMode(initMode) + + val pipeline = new Pipeline() + .setStages(Array(assembler, kMeans)) + .fit(data) + + new KMeansWrapper(pipeline) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index d23e4fc9d1..551e75dc0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -20,8 +20,7 @@ package org.apache.spark.ml.api.r import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} -import org.apache.spark.ml.clustering.{KMeans, KMeansModel} -import org.apache.spark.ml.feature.{RFormula, VectorAssembler} +import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.sql.DataFrame @@ -52,22 +51,6 @@ private[r] object SparkRWrappers { pipeline.fit(df) } - def fitKMeans( - df: DataFrame, - initMode: String, - maxIter: Double, - k: Double, - columns: Array[String]): PipelineModel = { - val assembler = new VectorAssembler().setInputCols(columns) - val kMeans = new KMeans() - .setInitMode(initMode) - .setMaxIter(maxIter.toInt) - .setK(k.toInt) - .setFeaturesCol(assembler.getOutputCol) - val pipeline = new Pipeline().setStages(Array(assembler, kMeans)) - pipeline.fit(df) - } - def getModelCoefficients(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => { @@ -89,8 +72,6 @@ private[r] object SparkRWrappers { m.coefficients.toArray } } - case m: KMeansModel => - m.clusterCenters.flatMap(_.toArray) } } @@ -104,31 +85,6 @@ private[r] object SparkRWrappers { } } - def 
getKMeansModelSize(model: PipelineModel): Array[Int] = { - model.stages.last match { - case m: KMeansModel => Array(m.getK) ++ m.summary.size - case other => throw new UnsupportedOperationException( - s"KMeansModel required but ${other.getClass.getSimpleName} found.") - } - } - - def getKMeansCluster(model: PipelineModel, method: String): DataFrame = { - model.stages.last match { - case m: KMeansModel => - if (method == "centers") { - // Drop the assembled vector for easy-print to R side. - m.summary.predictions.drop(m.summary.featuresCol) - } else if (method == "classes") { - m.summary.cluster - } else { - throw new UnsupportedOperationException( - s"Method (centers or classes) required but $method found.") - } - case other => throw new UnsupportedOperationException( - s"KMeansModel required but ${other.getClass.getSimpleName} found.") - } - } - def getModelFeatures(model: PipelineModel): Array[String] = { model.stages.last match { case m: LinearRegressionModel => @@ -147,10 +103,6 @@ private[r] object SparkRWrappers { } else { attrs.attributes.get.map(_.name.get) } - case m: KMeansModel => - val attrs = AttributeGroup.fromStructField( - m.summary.predictions.schema(m.summary.featuresCol)) - attrs.attributes.get.map(_.name.get) } } @@ -160,8 +112,6 @@ private[r] object SparkRWrappers { "LinearRegressionModel" case m: LogisticRegressionModel => "LogisticRegressionModel" - case m: KMeansModel => - "KMeansModel" } } } -- cgit v1.2.3 From 3715ecdf417b47423ff07145a5623d8d817c45ef Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 1 Apr 2016 17:02:48 +0800 Subject: [SPARK-14295][MLLIB][HOTFIX] Fixes Scala 2.10 compilation failure ## What changes were proposed in this pull request? Fixes a compilation failure introduced in PR #12088 under Scala 2.10. ## How was this patch tested? Compilation. Author: Cheng Lian Closes #12107 from liancheng/spark-14295-hotfix. --- mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4b9d77949f..774170ff40 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -120,7 +120,7 @@ object MLUtils { i += 1 } - (label, indices, values) + (label, indices.toArray, values.toArray) } /** -- cgit v1.2.3 From 0b04f8fdf1614308cb3e7e0c7282f7365cc3d1bb Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 1 Apr 2016 18:27:11 +0200 Subject: [SPARK-14184][SQL] Support native execution of SHOW DATABASE command and fix SHOW TABLE to use table identifier pattern ## What changes were proposed in this pull request? This PR addresses the following 1. Supports native execution of SHOW DATABASES command 2. Fixes SHOW TABLES to apply the identifier_with_wildcards pattern if supplied. SHOW TABLE syntax ``` SHOW TABLES [IN database_name] ['identifier_with_wildcards']; ``` SHOW DATABASES syntax ``` SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; ``` ## How was this patch tested? Tests added in SQLQuerySuite (both hive and sql contexts) and DDLCommandSuite Note: Since the table name pattern was not working , tests are added in both SQLQuerySuite to verify the application of the table pattern. Author: Dilip Biswal Closes #11991 from dilipbiswal/dkb_show_database. 
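Once SHOW TABLES and SHOW DATABASES are parsed natively, they can be issued through sqlContext.sql like any other statement. A hedged sketch mirroring the tests added below; the database and table name patterns are illustrative:

```
// Hypothetical usage of the natively parsed commands.
sqlContext.sql("CREATE DATABASE showdb1A")

// SHOW DATABASES with an optional wildcard pattern
sqlContext.sql("SHOW DATABASES LIKE 'showdb*'").show()

// SHOW TABLES now honors an identifier pattern, optionally scoped to a database
sqlContext.sql("SHOW TABLES IN default 'show1*|show2*'").show()
```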
--- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 6 +- .../scala/org/apache/spark/sql/SQLContext.scala | 4 +- .../spark/sql/execution/SparkSqlParser.scala | 22 +++++-- .../spark/sql/execution/command/commands.scala | 42 ++++++++++--- .../sql/execution/command/DDLCommandSuite.scala | 11 ++++ .../spark/sql/execution/command/DDLSuite.scala | 72 ++++++++++++++++++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 22 +++++++ 8 files changed, 163 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index a857e670da..5513bbdc7f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -114,7 +114,8 @@ statement | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction | EXPLAIN explainOption* statement #explain | SHOW TABLES ((FROM | IN) db=identifier)? - (LIKE (qualifiedName | pattern=STRING))? #showTables + (LIKE? pattern=STRING)? #showTables + | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? @@ -618,7 +619,7 @@ number ; nonReserved - : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS + : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER @@ -836,6 +837,7 @@ OUTPUTFORMAT: 'OUTPUTFORMAT'; INPUTDRIVER: 'INPUTDRIVER'; OUTPUTDRIVER: 'OUTPUTDRIVER'; DATABASE: 'DATABASE' | 'SCHEMA'; +DATABASES: 'DATABASES' | 'SCHEMAS'; DFS: 'DFS'; TRUNCATE: 'TRUNCATE'; METADATA: 'METADATA'; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0576a1a178..221782ee8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -781,7 +781,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(): DataFrame = { - Dataset.ofRows(this, ShowTablesCommand(None)) + Dataset.ofRows(this, ShowTablesCommand(None, None)) } /** @@ -793,7 +793,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(databaseName: String): DataFrame = { - Dataset.ofRows(this, ShowTablesCommand(Some(databaseName))) + Dataset.ofRows(this, ShowTablesCommand(Some(databaseName), None)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 16a899e01f..7efe98dd18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -70,12 +70,26 @@ class SparkSqlAstBuilder extends AstBuilder { /** * Create a [[ShowTablesCommand]] logical plan. 
+ * Example SQL : + * {{{ + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; + * }}} */ override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { - if (ctx.LIKE != null) { - logWarning("SHOW TABLES LIKE option is ignored.") - } - ShowTablesCommand(Option(ctx.db).map(_.getText)) + ShowTablesCommand( + Option(ctx.db).map(_.getText), + Option(ctx.pattern).map(string)) + } + + /** + * Create a [[ShowDatabasesCommand]] logical plan. + * Example SQL: + * {{{ + * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; + * }}} + */ + override def visitShowDatabases(ctx: ShowDatabasesContext): LogicalPlan = withOrigin(ctx) { + ShowDatabasesCommand(Option(ctx.pattern).map(string)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 964f0a7a7b..f90d8717ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -322,18 +322,17 @@ case class DescribeCommand( * If a databaseName is not given, the current database will be used. * The syntax of using this command in SQL is: * {{{ - * SHOW TABLES [IN databaseName] + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; * }}} */ -case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { +case class ShowTablesCommand( + databaseName: Option[String], + tableIdentifierPattern: Option[String]) extends RunnableCommand { // The result of SHOW TABLES has two columns, tableName and isTemporary. override val output: Seq[Attribute] = { - val schema = StructType( - StructField("tableName", StringType, false) :: - StructField("isTemporary", BooleanType, false) :: Nil) - - schema.toAttributes + AttributeReference("tableName", StringType, nullable = false)() :: + AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil } override def run(sqlContext: SQLContext): Seq[Row] = { @@ -341,11 +340,36 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma // instead of calling tables in sqlContext. val catalog = sqlContext.sessionState.catalog val db = databaseName.getOrElse(catalog.getCurrentDatabase) - val rows = catalog.listTables(db).map { t => + val tables = + tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) + tables.map { t => val isTemp = t.database.isEmpty Row(t.table, isTemp) } - rows + } +} + +/** + * A command for users to list the databases/schemas. + * If a databasePattern is supplied then the databases that only matches the + * pattern would be listed. 
+ * The syntax of using this command in SQL is: + * {{{ + * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; + * }}} + */ +case class ShowDatabasesCommand(databasePattern: Option[String]) extends RunnableCommand { + + // The result of SHOW DATABASES has one column called 'result' + override val output: Seq[Attribute] = { + AttributeReference("result", StringType, nullable = false)() :: Nil + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val databases = + databasePattern.map(catalog.listDatabases(_)).getOrElse(catalog.listDatabases()) + databases.map { d => Row(d) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index cebf9c856d..458f36e832 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -762,4 +762,15 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + test("show databases") { + val sql1 = "SHOW DATABASES" + val sql2 = "SHOW DATABASES LIKE 'defau*'" + val parsed1 = parser.parsePlan(sql1) + val expected1 = ShowDatabasesCommand(None) + val parsed2 = parser.parsePlan(sql2) + val expected2 = ShowDatabasesCommand(Some("defau*")) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index f148f2d4ea..885a04af59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -45,6 +45,7 @@ class DDLSuite extends QueryTest with SharedSQLContext { dbNames.foreach { name => sqlContext.sql(s"DROP DATABASE IF EXISTS $name CASCADE") } + sqlContext.sessionState.catalog.setCurrentDatabase("default") } } @@ -159,4 +160,75 @@ class DDLSuite extends QueryTest with SharedSQLContext { } // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext + + test("show tables") { + withTempTable("show1a", "show2b") { + sql( + """ + |CREATE TEMPORARY TABLE show1a + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + | + |) + """.stripMargin) + sql( + """ + |CREATE TEMPORARY TABLE show2b + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + checkAnswer( + sql("SHOW TABLES IN default 'show1*'"), + Row("show1a", true) :: Nil) + + checkAnswer( + sql("SHOW TABLES IN default 'show1*|show2*'"), + Row("show1a", true) :: + Row("show2b", true) :: Nil) + + checkAnswer( + sql("SHOW TABLES 'show1*|show2*'"), + Row("show1a", true) :: + Row("show2b", true) :: Nil) + + assert( + sql("SHOW TABLES").count() >= 2) + assert( + sql("SHOW TABLES IN default").count() >= 2) + } + } + + test("show databases") { + withDatabase("showdb1A", "showdb2B") { + sql("CREATE DATABASE showdb1A") + sql("CREATE DATABASE showdb2B") + + assert( + sql("SHOW DATABASES").count() >= 2) + + checkAnswer( + sql("SHOW DATABASES LIKE '*db1A'"), + Row("showdb1A") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE 'showdb1A'"), + Row("showdb1A") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE 
'*db1A|*db2B'"), + Row("showdb1A") :: + Row("showdb2B") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE 'non-existentdb'"), + Nil) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 7ad7f92bd2..e93b0c145f 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -177,7 +177,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { } test("Single command with -e") { - runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "") } test("Single command with --database") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6199253d34..c203518fdd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1811,4 +1811,26 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("show tables") { + withTable("show1a", "show2b") { + sql("CREATE TABLE show1a(c1 int)") + sql("CREATE TABLE show2b(c2 int)") + checkAnswer( + sql("SHOW TABLES IN default 'show1*'"), + Row("show1a", false) :: Nil) + checkAnswer( + sql("SHOW TABLES IN default 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + checkAnswer( + sql("SHOW TABLES 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + assert( + sql("SHOW TABLES").count() >= 2) + assert( + sql("SHOW TABLES IN default").count() >= 2) + } + } } -- cgit v1.2.3 From a471c7f9eaa59d55dfff5b9d1a858f304a6b3a84 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Fri, 1 Apr 2016 18:33:31 +0200 Subject: [SPARK-14133][SQL] Throws exception for unsupported create/drop/alter index , and lock/unlock operations. ## What changes were proposed in this pull request? This PR throws Unsupported Operation exception for create index, drop index, alter index , lock table , lock database, unlock table, and unlock database operations that are not supported in Spark SQL. Currently these operations are executed executed by Hive. Error: spark-sql> drop index my_index on my_table; Error in query: Unsupported operation: drop index(line 1, pos 0) ## How was this patch tested? Added test cases to HiveQuerySuite yhuai hvanhovell andrewor14 Author: sureshthalamati Closes #12069 from sureshthalamati/unsupported_ddl_spark-14133. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 12 ++++++++++-- .../spark/sql/hive/execution/HiveCompatibilitySuite.scala | 10 ++++++---- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 15 +++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5513bbdc7f..d1747b9915 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -146,7 +146,7 @@ hiveNativeCommands | ROLLBACK WORK? | SHOW PARTITIONS tableIdentifier partitionSpec? 
| DFS .*? - | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOCK | UNLOCK | MSCK | LOAD) .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | MSCK | LOAD) .*? ; unsupportedHiveNativeCommands @@ -166,6 +166,13 @@ unsupportedHiveNativeCommands | kw1=SHOW kw2=TRANSACTIONS | kw1=SHOW kw2=INDEXES | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE ; createTableHeader @@ -640,7 +647,7 @@ nonReserved | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE - | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEXES | LOCKS | OPTION + | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION ; SELECT: 'SELECT'; @@ -861,6 +868,7 @@ ROLES: 'ROLES'; COMPACTIONS: 'COMPACTIONS'; PRINCIPALS: 'PRINCIPALS'; TRANSACTIONS: 'TRANSACTIONS'; +INDEX: 'INDEX'; INDEXES: 'INDEXES'; LOCKS: 'LOCKS'; OPTION: 'OPTION'; diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bedbf9ae17..695b5ef733 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -352,7 +352,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "show_create_table_index", "show_create_table_partitioned", "show_create_table_serde", - "show_create_table_view" + "show_create_table_view", + + // Index commands are not supported + "drop_index", + "drop_index_removes_partition_dirs", + "alter_index" ) /** @@ -369,7 +374,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter3", "alter4", "alter5", - "alter_index", "alter_merge_2", "alter_partition_format_loc", "alter_partition_with_whitelist", @@ -496,8 +500,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "distinct_stats", "drop_database_removes_partition_dirs", "drop_function", - "drop_index", - "drop_index_removes_partition_dirs", "drop_multi_partitions", "drop_partitions_filter", "drop_partitions_filter2", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 79774f5913..58259060bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1280,6 +1280,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertUnsupportedFeature { sql("SHOW LOCKS my_table") } } + test("lock/unlock table and database commands are not supported") { + assertUnsupportedFeature { sql("LOCK TABLE my_table SHARED") } + assertUnsupportedFeature { sql("UNLOCK TABLE my_table") } + assertUnsupportedFeature { sql("LOCK DATABASE my_db SHARED") } + assertUnsupportedFeature { sql("UNLOCK DATABASE my_db") } + } + + test("create/drop/alter index commands are not supported") { + assertUnsupportedFeature { + sql("CREATE INDEX my_index ON TABLE my_table(a) 
as 'COMPACT' WITH DEFERRED REBUILD")} + assertUnsupportedFeature { sql("DROP INDEX my_index ON my_table") } + assertUnsupportedFeature { sql("ALTER INDEX my_index ON my_table REBUILD")} + assertUnsupportedFeature { + sql("ALTER INDEX my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")")} + } } // for SPARK-2180 test -- cgit v1.2.3 From 58e6bc827f1f9dc1afee07dca1bee1f56553dd20 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 1 Apr 2016 10:36:01 -0700 Subject: [MINOR] [SQL] Update usage of `debug` by removing `typeCheck` and adding `debugCodegen` ## What changes were proposed in this pull request? This PR updates the usage comments of `debug` according to the following commits. - [SPARK-9754](https://issues.apache.org/jira/browse/SPARK-9754) removed `typeCheck`. - [SPARK-14227](https://issues.apache.org/jira/browse/SPARK-14227) added `debugCodegen`. ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #12094 from dongjoon-hyun/minor_fix_debug_usage. --- .../src/main/scala/org/apache/spark/sql/execution/debug/package.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 9916482a68..3a174ed94c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -35,8 +35,8 @@ import org.apache.spark.sql.internal.SQLConf * Usage: * {{{ * import org.apache.spark.sql.execution.debug._ - * sql("SELECT key FROM src").debug() - * dataFrame.typeCheck() + * sql("SELECT 1").debug() + * sql("SELECT 1").debugCodegen() * }}} */ package object debug { -- cgit v1.2.3 From 8ba2b7f28fee39c4839e5ea125bd25f5091a3a1e Mon Sep 17 00:00:00 2001 From: jerryshao Date: Fri, 1 Apr 2016 10:52:13 -0700 Subject: [SPARK-12343][YARN] Simplify Yarn client and client argument ## What changes were proposed in this pull request? Currently in Spark on YARN, configurations can be passed through SparkConf, env and command arguments, some parts are duplicated, like client argument and SparkConf. So here propose to simplify the command arguments. ## How was this patch tested? This patch is tested manually with unit test. CC vanzin tgravescs , please help to suggest this proposal. The original purpose of this JIRA is to remove `ClientArguments`, through refactoring some arguments like `--class`, `--arg` are not so easy to replace, so here I remove the most part of command line arguments, only keep the minimal set. Author: jerryshao Closes #11603 from jerryshao/SPARK-12343. 
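To make the consolidation concrete, below is a hedged sketch (not taken from the patch; all values are placeholders) of how settings that used to be dedicated `Client` command-line arguments are now expressed purely through `SparkConf` properties, or the equivalent `--conf` options of spark-submit:

```
// Illustrative mapping from the removed Client arguments to configuration keys.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("example-app")                                // formerly --name
  .set("spark.driver.memory", "4g")                         // formerly --driver-memory
  .set("spark.executor.memory", "2g")                       // formerly --executor-memory
  .set("spark.executor.cores", "2")                         // formerly --executor-cores
  .set("spark.executor.instances", "4")                     // formerly --num-executors
  .set("spark.yarn.queue", "default")                       // formerly --queue
  .set("spark.yarn.dist.jars", "local:/path/to/extra.jar")  // formerly --addJars (placeholder path)
```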
--- .../org/apache/spark/deploy/SparkSubmit.scala | 44 ++--- .../org/apache/spark/internal/config/package.scala | 14 ++ .../org/apache/spark/deploy/SparkSubmitSuite.scala | 19 ++- docs/running-on-yarn.md | 7 + .../spark/deploy/yarn/ApplicationMaster.scala | 2 +- .../deploy/yarn/ApplicationMasterArguments.scala | 18 +- .../org/apache/spark/deploy/yarn/Client.scala | 89 ++++++---- .../apache/spark/deploy/yarn/ClientArguments.scala | 189 +-------------------- .../apache/spark/deploy/yarn/YarnAllocator.scala | 6 +- .../apache/spark/deploy/yarn/YarnRMClient.scala | 5 +- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 2 +- .../org/apache/spark/deploy/yarn/config.scala | 18 +- .../cluster/YarnClientSchedulerBackend.scala | 42 +---- .../org/apache/spark/deploy/yarn/ClientSuite.scala | 19 ++- .../spark/deploy/yarn/YarnAllocatorSuite.scala | 8 +- .../spark/deploy/yarn/YarnClusterSuite.scala | 46 ++++- 16 files changed, 186 insertions(+), 342 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4049fc0c41..926e1ff7a8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -441,7 +441,6 @@ object SparkSubmit { OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.submit.deployMode"), OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), - OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"), OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.driver.memory"), @@ -452,27 +451,15 @@ object SparkSubmit { OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.driver.extraLibraryPath"), - // Yarn client only - OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), + // Yarn only + OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, sysProp = "spark.executor.instances"), - OptionAssigner(args.files, YARN, CLIENT, sysProp = "spark.yarn.dist.files"), - OptionAssigner(args.archives, YARN, CLIENT, sysProp = "spark.yarn.dist.archives"), - OptionAssigner(args.principal, YARN, CLIENT, sysProp = "spark.yarn.principal"), - OptionAssigner(args.keytab, YARN, CLIENT, sysProp = "spark.yarn.keytab"), - - // Yarn cluster only - OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"), - OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"), - OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"), - OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"), - OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"), - OptionAssigner(args.executorCores, YARN, CLUSTER, clOption = "--executor-cores"), - OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"), - OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"), - OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"), - OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"), - OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"), + OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"), + OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = 
"spark.yarn.dist.files"), + OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, @@ -483,10 +470,11 @@ object SparkSubmit { sysProp = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.files"), - OptionAssigner(args.jars, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.driverMemory, STANDALONE | MESOS, CLUSTER, + OptionAssigner(args.jars, LOCAL, CLIENT, sysProp = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN, CLUSTER, sysProp = "spark.driver.memory"), - OptionAssigner(args.driverCores, STANDALONE | MESOS, CLUSTER, + OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN, CLUSTER, sysProp = "spark.driver.cores"), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, sysProp = "spark.driver.supervise"), @@ -550,6 +538,10 @@ object SparkSubmit { if (args.isPython) { sysProps.put("spark.yarn.isPython", "true") } + + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } } // assure a keytab is available from any place in a JVM @@ -576,9 +568,6 @@ object SparkSubmit { childMainClass = "org.apache.spark.deploy.yarn.Client" if (args.isPython) { childArgs += ("--primary-py-file", args.primaryResource) - if (args.pyFiles != null) { - childArgs += ("--py-files", args.pyFiles) - } childArgs += ("--class", "org.apache.spark.deploy.PythonRunner") } else if (args.isR) { val mainFile = new Path(args.primaryResource).getName @@ -627,7 +616,8 @@ object SparkSubmit { "spark.jars", "spark.files", "spark.yarn.dist.files", - "spark.yarn.dist.archives") + "spark.yarn.dist.archives", + "spark.yarn.dist.jars") pathConfigs.foreach { config => // Replace old URIs with resolved URIs, if they exist sysProps.get(config).foreach { oldValue => diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index f2f20b3207..968c5192ac 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -18,6 +18,7 @@ package org.apache.spark.internal import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.network.util.ByteUnit package object config { @@ -33,6 +34,10 @@ package object config { private[spark] val DRIVER_USER_CLASS_PATH_FIRST = ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.withDefault(false) + private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") + .bytesConf(ByteUnit.MiB) + .withDefaultString("1g") + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.optional @@ -45,6 +50,10 @@ package object config { private[spark] val EXECUTOR_USER_CLASS_PATH_FIRST = ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.withDefault(false) + private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") + .bytesConf(ByteUnit.MiB) + .withDefaultString("1g") + private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal .booleanConf.withDefault(false) @@ -73,4 
+82,9 @@ package object config { private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances").intConf.optional + private[spark] val PY_FILES = ConfigBuilder("spark.submit.pyFiles") + .internal + .stringConf + .toSequence + .withDefault(Nil) } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 96cb4fd0eb..2718976992 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -199,21 +199,21 @@ class SparkSubmitSuite val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") - childArgsStr should include ("--executor-memory 5g") - childArgsStr should include ("--driver-memory 4g") - childArgsStr should include ("--executor-cores 5") childArgsStr should include ("--arg arg1 --arg arg2") - childArgsStr should include ("--queue thequeue") childArgsStr should include regex ("--jar .*thejar.jar") - childArgsStr should include regex ("--addJars .*one.jar,.*two.jar,.*three.jar") - childArgsStr should include regex ("--files .*file1.txt,.*file2.txt") - childArgsStr should include regex ("--archives .*archive1.txt,.*archive2.txt") mainClass should be ("org.apache.spark.deploy.yarn.Client") classpath should have length (0) + + sysProps("spark.executor.memory") should be ("5g") + sysProps("spark.driver.memory") should be ("4g") + sysProps("spark.executor.cores") should be ("5") + sysProps("spark.yarn.queue") should be ("thequeue") + sysProps("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") + sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") sysProps("spark.app.name") should be ("beauty") sysProps("spark.ui.enabled") should be ("false") sysProps("SPARK_SUBMIT") should be ("true") - sysProps.keys should not contain ("spark.jars") } test("handles YARN client mode") { @@ -249,7 +249,8 @@ class SparkSubmitSuite sysProps("spark.executor.instances") should be ("6") sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") + sysProps("spark.yarn.dist.jars") should include + regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") sysProps("SPARK_SUBMIT") should be ("true") sysProps("spark.ui.enabled") should be ("false") } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index c775fe710f..bb83272ec8 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -215,6 +215,13 @@ If you need a reference to the proper location to put log files in the YARN so t Comma-separated list of files to be placed in the working directory of each executor. 
+ + + + + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e941089d1b..9e8453429c 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -662,7 +662,7 @@ object ApplicationMaster extends Logging { SignalLogger.register(log) val amArgs = new ApplicationMasterArguments(args) SparkHadoopUtil.get.runAsSparkUser { () => - master = new ApplicationMaster(amArgs, new YarnRMClient(amArgs)) + master = new ApplicationMaster(amArgs, new YarnRMClient) System.exit(master.run()) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala index 6987e5a55f..5cdec87667 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala @@ -27,8 +27,6 @@ class ApplicationMasterArguments(val args: Array[String]) { var primaryPyFile: String = null var primaryRFile: String = null var userArgs: Seq[String] = Nil - var executorMemory = 1024 - var executorCores = 1 var propertiesFile: String = null parseArgs(args.toList) @@ -58,18 +56,10 @@ class ApplicationMasterArguments(val args: Array[String]) { primaryRFile = value args = tail - case ("--args" | "--arg") :: value :: tail => + case ("--arg") :: value :: tail => userArgsBuffer += value args = tail - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => - executorMemory = value - args = tail - - case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail => - executorCores = value - args = tail - case ("--properties-file") :: value :: tail => propertiesFile = value args = tail @@ -101,12 +91,8 @@ class ApplicationMasterArguments(val args: Array[String]) { | --class CLASS_NAME Name of your application's main class | --primary-py-file A main Python file | --primary-r-file A main R file - | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to - | place on the PYTHONPATH for Python apps. - | --args ARGS Arguments to be passed to your application's main class. + | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --executor-cores NUM Number of cores for the executors (Default: 1) - | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) | --properties-file FILE Path to a custom Spark properties file. 
""".stripMargin) // scalastyle:on println diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index f0f13a16e0..4dd3ccdf37 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -64,21 +64,44 @@ private[spark] class Client( extends Logging { import Client._ + import YarnSparkHadoopUtil._ def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, SparkHadoopUtil.get.newConfiguration(spConf), spConf) private val yarnClient = YarnClient.createYarnClient private val yarnConf = new YarnConfiguration(hadoopConf) - private var credentials: Credentials = null - private val amMemoryOverhead = args.amMemoryOverhead // MB - private val executorMemoryOverhead = args.executorMemoryOverhead // MB + + private val isClusterMode = sparkConf.get("spark.submit.deployMode", "client") == "cluster" + + // AM related configurations + private val amMemory = if (isClusterMode) { + sparkConf.get(DRIVER_MEMORY).toInt + } else { + sparkConf.get(AM_MEMORY).toInt + } + private val amMemoryOverhead = { + val amMemoryOverheadEntry = if (isClusterMode) DRIVER_MEMORY_OVERHEAD else AM_MEMORY_OVERHEAD + sparkConf.get(amMemoryOverheadEntry).getOrElse( + math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt + } + private val amCores = if (isClusterMode) { + sparkConf.get(DRIVER_CORES) + } else { + sparkConf.get(AM_CORES) + } + + // Executor related configurations + private val executorMemory = sparkConf.get(EXECUTOR_MEMORY) + private val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( + math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt + private val distCacheMgr = new ClientDistributedCacheManager() - private val isClusterMode = args.isClusterMode private var loginFromKeytab = false private var principal: String = null private var keytab: String = null + private var credentials: Credentials = null private val launcherBackend = new LauncherBackend() { override def onStopRequest(): Unit = { @@ -179,8 +202,8 @@ private[spark] class Client( newApp: YarnClientApplication, containerContext: ContainerLaunchContext): ApplicationSubmissionContext = { val appContext = newApp.getApplicationSubmissionContext - appContext.setApplicationName(args.appName) - appContext.setQueue(args.amQueue) + appContext.setApplicationName(sparkConf.get("spark.app.name", "Spark")) + appContext.setQueue(sparkConf.get(QUEUE_NAME)) appContext.setAMContainerSpec(containerContext) appContext.setApplicationType("SPARK") @@ -217,8 +240,8 @@ private[spark] class Client( } val capability = Records.newRecord(classOf[Resource]) - capability.setMemory(args.amMemory + amMemoryOverhead) - capability.setVirtualCores(args.amCores) + capability.setMemory(amMemory + amMemoryOverhead) + capability.setVirtualCores(amCores) sparkConf.get(AM_NODE_LABEL_EXPRESSION) match { case Some(expr) => @@ -272,16 +295,16 @@ private[spark] class Client( val maxMem = newAppResponse.getMaximumResourceCapability().getMemory() logInfo("Verifying our application has not requested more than the maximum " + s"memory capability of the cluster ($maxMem MB per container)") - val executorMem = args.executorMemory + executorMemoryOverhead + val executorMem = executorMemory + executorMemoryOverhead if (executorMem > maxMem) { - throw new IllegalArgumentException(s"Required executor memory (${args.executorMemory}" + + throw 
new IllegalArgumentException(s"Required executor memory ($executorMemory" + s"+$executorMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + "Please check the values of 'yarn.scheduler.maximum-allocation-mb' and/or " + "'yarn.nodemanager.resource.memory-mb'.") } - val amMem = args.amMemory + amMemoryOverhead + val amMem = amMemory + amMemoryOverhead if (amMem > maxMem) { - throw new IllegalArgumentException(s"Required AM memory (${args.amMemory}" + + throw new IllegalArgumentException(s"Required AM memory ($amMemory" + s"+$amMemoryOverhead MB) is above the max threshold ($maxMem MB) of this cluster! " + "Please increase the value of 'yarn.scheduler.maximum-allocation-mb'.") } @@ -493,17 +516,15 @@ private[spark] class Client( */ val cachedSecondaryJarLinks = ListBuffer.empty[String] List( - (args.addJars, LocalResourceType.FILE, true), - (args.files, LocalResourceType.FILE, false), - (args.archives, LocalResourceType.ARCHIVE, false) + (sparkConf.get(JARS_TO_DISTRIBUTE), LocalResourceType.FILE, true), + (sparkConf.get(FILES_TO_DISTRIBUTE), LocalResourceType.FILE, false), + (sparkConf.get(ARCHIVES_TO_DISTRIBUTE), LocalResourceType.ARCHIVE, false) ).foreach { case (flist, resType, addToClasspath) => - if (flist != null && !flist.isEmpty()) { - flist.split(',').foreach { file => - val (_, localizedPath) = distribute(file, resType = resType) - require(localizedPath != null) - if (addToClasspath) { - cachedSecondaryJarLinks += localizedPath - } + flist.foreach { file => + val (_, localizedPath) = distribute(file, resType = resType) + require(localizedPath != null) + if (addToClasspath) { + cachedSecondaryJarLinks += localizedPath } } } @@ -519,7 +540,7 @@ private[spark] class Client( // The python files list needs to be treated especially. All files that are not an // archive need to be placed in a subdirectory that will be added to PYTHONPATH. - args.pyFiles.foreach { f => + sparkConf.get(PY_FILES).foreach { f => val targetDir = if (f.endsWith(".py")) Some(LOCALIZED_PYTHON_DIR) else None distribute(f, targetDir = targetDir) } @@ -678,7 +699,7 @@ private[spark] class Client( // // NOTE: the code currently does not handle .py files defined with a "local:" scheme. 
val pythonPath = new ListBuffer[String]() - val (pyFiles, pyArchives) = args.pyFiles.partition(_.endsWith(".py")) + val (pyFiles, pyArchives) = sparkConf.get(PY_FILES).partition(_.endsWith(".py")) if (pyFiles.nonEmpty) { pythonPath += buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_PYTHON_DIR) @@ -775,7 +796,7 @@ private[spark] class Client( var prefixEnv: Option[String] = None // Add Xmx for AM memory - javaOpts += "-Xmx" + args.amMemory + "m" + javaOpts += "-Xmx" + amMemory + "m" val tmpDir = new Path( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), @@ -879,8 +900,6 @@ private[spark] class Client( val amArgs = Seq(amClass) ++ userClass ++ userJar ++ primaryPyFile ++ primaryRFile ++ userArgs ++ Seq( - "--executor-memory", args.executorMemory.toString + "m", - "--executor-cores", args.executorCores.toString, "--properties-file", buildPath(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), LOCALIZED_CONF_DIR, SPARK_CONF_FILE)) @@ -919,10 +938,10 @@ private[spark] class Client( } def setupCredentials(): Unit = { - loginFromKeytab = args.principal != null || sparkConf.contains(PRINCIPAL.key) + loginFromKeytab = sparkConf.contains(PRINCIPAL.key) if (loginFromKeytab) { - principal = Option(args.principal).orElse(sparkConf.get(PRINCIPAL)).get - keytab = Option(args.keytab).orElse(sparkConf.get(KEYTAB)).orNull + principal = sparkConf.get(PRINCIPAL).get + keytab = sparkConf.get(KEYTAB).orNull require(keytab != null, "Keytab must be specified when principal is specified.") logInfo("Attempting to login to the Kerberos" + @@ -1084,7 +1103,7 @@ private[spark] class Client( } -object Client extends Logging { +private object Client extends Logging { def main(argStrings: Array[String]) { if (!sys.props.contains("SPARK_SUBMIT")) { @@ -1097,11 +1116,7 @@ object Client extends Logging { System.setProperty("SPARK_YARN_MODE", "true") val sparkConf = new SparkConf - val args = new ClientArguments(argStrings, sparkConf) - // to maintain backwards-compatibility - if (!Utils.isDynamicAllocationEnabled(sparkConf)) { - sparkConf.setIfMissing(EXECUTOR_INSTANCES, args.numExecutors) - } + val args = new ClientArguments(argStrings) new Client(args, sparkConf).run() } @@ -1246,7 +1261,7 @@ object Client extends Logging { val secondaryJars = if (args != null) { - getSecondaryJarUris(Option(args.addJars).map(_.split(",").toSeq)) + getSecondaryJarUris(Option(sparkConf.get(JARS_TO_DISTRIBUTE))) } else { getSecondaryJarUris(sparkConf.get(SECONDARY_JARS)) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 47b4cc3009..61c027ec44 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -19,118 +19,20 @@ package org.apache.spark.deploy.yarn import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ -import org.apache.spark.deploy.yarn.config._ -import org.apache.spark.internal.config._ -import org.apache.spark.util.{IntParam, MemoryParam, Utils} - // TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! 
-private[spark] class ClientArguments( - args: Array[String], - sparkConf: SparkConf) { +private[spark] class ClientArguments(args: Array[String]) { - var addJars: String = null - var files: String = null - var archives: String = null var userJar: String = null var userClass: String = null - var pyFiles: Seq[String] = Nil var primaryPyFile: String = null var primaryRFile: String = null var userArgs: ArrayBuffer[String] = new ArrayBuffer[String]() - var executorMemory = 1024 // MB - var executorCores = 1 - var numExecutors = DEFAULT_NUMBER_EXECUTORS - var amQueue = sparkConf.get(QUEUE_NAME) - var amMemory: Int = _ - var amCores: Int = _ - var appName: String = "Spark" - var priority = 0 - var principal: String = null - var keytab: String = null - def isClusterMode: Boolean = userClass != null - - private var driverMemory: Int = Utils.DEFAULT_DRIVER_MEM_MB // MB - private var driverCores: Int = 1 - private val isDynamicAllocationEnabled = Utils.isDynamicAllocationEnabled(sparkConf) parseArgs(args.toList) - loadEnvironmentArgs() - validateArgs() - - // Additional memory to allocate to containers - val amMemoryOverheadEntry = if (isClusterMode) DRIVER_MEMORY_OVERHEAD else AM_MEMORY_OVERHEAD - val amMemoryOverhead = sparkConf.get(amMemoryOverheadEntry).getOrElse( - math.max((MEMORY_OVERHEAD_FACTOR * amMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt - - val executorMemoryOverhead = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( - math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toLong, MEMORY_OVERHEAD_MIN)).toInt - - /** Load any default arguments provided through environment variables and Spark properties. */ - private def loadEnvironmentArgs(): Unit = { - // For backward compatibility, SPARK_YARN_DIST_{ARCHIVES/FILES} should be resolved to hdfs://, - // while spark.yarn.dist.{archives/files} should be resolved to file:// (SPARK-2051). - files = Option(files) - .orElse(sparkConf.get(FILES_TO_DISTRIBUTE).map(p => Utils.resolveURIs(p))) - .orElse(sys.env.get("SPARK_YARN_DIST_FILES")) - .orNull - archives = Option(archives) - .orElse(sparkConf.get(ARCHIVES_TO_DISTRIBUTE).map(p => Utils.resolveURIs(p))) - .orElse(sys.env.get("SPARK_YARN_DIST_ARCHIVES")) - .orNull - // If dynamic allocation is enabled, start at the configured initial number of executors. - // Default to minExecutors if no initialExecutors is set. - numExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(sparkConf, numExecutors) - principal = Option(principal) - .orElse(sparkConf.get(PRINCIPAL)) - .orNull - keytab = Option(keytab) - .orElse(sparkConf.get(KEYTAB)) - .orNull - } - - /** - * Fail fast if any arguments provided are invalid. - * This is intended to be called only after the provided arguments have been parsed. - */ - private def validateArgs(): Unit = { - if (numExecutors < 0 || (!isDynamicAllocationEnabled && numExecutors == 0)) { - throw new IllegalArgumentException( - s""" - |Number of executors was $numExecutors, but must be at least 1 - |(or 0 if dynamic executor allocation is enabled). 
- |${getUsageMessage()} - """.stripMargin) - } - if (executorCores < sparkConf.get(CPUS_PER_TASK)) { - throw new SparkException(s"Executor cores must not be less than ${CPUS_PER_TASK.key}.") - } - // scalastyle:off println - if (isClusterMode) { - for (key <- Seq(AM_MEMORY.key, AM_MEMORY_OVERHEAD.key, AM_CORES.key)) { - if (sparkConf.contains(key)) { - println(s"$key is set but does not apply in cluster mode.") - } - } - amMemory = driverMemory - amCores = driverCores - } else { - for (key <- Seq(DRIVER_MEMORY_OVERHEAD.key, DRIVER_CORES.key)) { - if (sparkConf.contains(key)) { - println(s"$key is set but does not apply in client mode.") - } - } - amMemory = sparkConf.get(AM_MEMORY).toInt - amCores = sparkConf.get(AM_CORES) - } - // scalastyle:on println - } private def parseArgs(inputArgs: List[String]): Unit = { var args = inputArgs - // scalastyle:off println while (!args.isEmpty) { args match { case ("--jar") :: value :: tail => @@ -149,88 +51,16 @@ private[spark] class ClientArguments( primaryRFile = value args = tail - case ("--args" | "--arg") :: value :: tail => - if (args(0) == "--args") { - println("--args is deprecated. Use --arg instead.") - } + case ("--arg") :: value :: tail => userArgs += value args = tail - case ("--master-class" | "--am-class") :: value :: tail => - println(s"${args(0)} is deprecated and is not used anymore.") - args = tail - - case ("--master-memory" | "--driver-memory") :: MemoryParam(value) :: tail => - if (args(0) == "--master-memory") { - println("--master-memory is deprecated. Use --driver-memory instead.") - } - driverMemory = value - args = tail - - case ("--driver-cores") :: IntParam(value) :: tail => - driverCores = value - args = tail - - case ("--num-workers" | "--num-executors") :: IntParam(value) :: tail => - if (args(0) == "--num-workers") { - println("--num-workers is deprecated. Use --num-executors instead.") - } - numExecutors = value - args = tail - - case ("--worker-memory" | "--executor-memory") :: MemoryParam(value) :: tail => - if (args(0) == "--worker-memory") { - println("--worker-memory is deprecated. Use --executor-memory instead.") - } - executorMemory = value - args = tail - - case ("--worker-cores" | "--executor-cores") :: IntParam(value) :: tail => - if (args(0) == "--worker-cores") { - println("--worker-cores is deprecated. 
Use --executor-cores instead.") - } - executorCores = value - args = tail - - case ("--queue") :: value :: tail => - amQueue = value - args = tail - - case ("--name") :: value :: tail => - appName = value - args = tail - - case ("--addJars") :: value :: tail => - addJars = value - args = tail - - case ("--py-files") :: value :: tail => - pyFiles = value.split(",") - args = tail - - case ("--files") :: value :: tail => - files = value - args = tail - - case ("--archives") :: value :: tail => - archives = value - args = tail - - case ("--principal") :: value :: tail => - principal = value - args = tail - - case ("--keytab") :: value :: tail => - keytab = value - args = tail - case Nil => case _ => throw new IllegalArgumentException(getUsageMessage(args)) } } - // scalastyle:on println if (primaryPyFile != null && primaryRFile != null) { throw new IllegalArgumentException("Cannot have primary-py-file and primary-r-file" + @@ -240,7 +70,6 @@ private[spark] class ClientArguments( private def getUsageMessage(unknownParam: List[String] = null): String = { val message = if (unknownParam != null) s"Unknown/unsupported param $unknownParam\n" else "" - val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB message + s""" |Usage: org.apache.spark.deploy.yarn.Client [options] @@ -252,20 +81,6 @@ private[spark] class ClientArguments( | --primary-r-file A main R file | --arg ARG Argument to be passed to your application's main class. | Multiple invocations are possible, each will be passed in order. - | --num-executors NUM Number of executors to start (Default: 2) - | --executor-cores NUM Number of cores per executor (Default: 1). - | --driver-memory MEM Memory for driver (e.g. 1000M, 2G) (Default: $mem_mb Mb) - | --driver-cores NUM Number of cores used by the driver (Default: 1). - | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G) - | --name NAME The name of your application (Default: Spark) - | --queue QUEUE The hadoop queue to use for allocation requests (Default: - | 'default') - | --addJars jars Comma separated list of local jars that want SparkContext.addJar - | to work with. - | --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to - | place on the PYTHONPATH for Python apps. - | --files files Comma separated list of files to be distributed with the job. - | --archives archives Comma separated list of archives to be distributed with the job. 
""".stripMargin } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index d094302362..7d71a642f6 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -36,6 +36,7 @@ import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor @@ -61,7 +62,6 @@ private[yarn] class YarnAllocator( sparkConf: SparkConf, amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, - args: ApplicationMasterArguments, securityMgr: SecurityManager) extends Logging { @@ -107,12 +107,12 @@ private[yarn] class YarnAllocator( private val containerIdToExecutorId = new HashMap[ContainerId, String] // Executor memory in MB. - protected val executorMemory = args.executorMemory + protected val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt // Additional memory overhead. protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse( math.max((MEMORY_OVERHEAD_FACTOR * executorMemory).toInt, MEMORY_OVERHEAD_MIN)).toInt // Number of cores per executor. - protected val executorCores = args.executorCores + protected val executorCores = sparkConf.get(EXECUTOR_CORES) // Resource capability requested for each executors private[yarn] val resource = Resource.newInstance(executorMemory + memoryOverhead, executorCores) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 83d30b7352..e7f7544664 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. 
*/ -private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logging { +private[spark] class YarnRMClient extends Logging { private var amClient: AMRMClient[ContainerRequest] = _ private var uiHistoryAddress: String = _ @@ -72,8 +72,7 @@ private[spark] class YarnRMClient(args: ApplicationMasterArguments) extends Logg amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) registered = true } - new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), args, - securityMgr) + new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr) } /** diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 2915e664be..5af2c29808 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -512,7 +512,7 @@ object YarnSparkHadoopUtil { val initialNumExecutors = conf.get(DYN_ALLOCATION_INITIAL_EXECUTORS) val maxNumExecutors = conf.get(DYN_ALLOCATION_MAX_EXECUTORS) require(initialNumExecutors >= minNumExecutors && initialNumExecutors <= maxNumExecutors, - s"initial executor number $initialNumExecutors must between min executor number" + + s"initial executor number $initialNumExecutors must between min executor number " + s"$minNumExecutors and max executor number $maxNumExecutors") initialNumExecutors diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 0789567ae6..a3b9134b58 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -85,11 +85,18 @@ package object config { private[spark] val ARCHIVES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.archives") .stringConf - .optional + .toSequence + .withDefault(Nil) private[spark] val FILES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.files") .stringConf - .optional + .toSequence + .withDefault(Nil) + + private[spark] val JARS_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.jars") + .stringConf + .toSequence + .withDefault(Nil) private[spark] val PRESERVE_STAGING_FILES = ConfigBuilder("spark.yarn.preserve.staging.files") .doc("Whether to preserve temporary files created by the job in HDFS.") @@ -183,7 +190,7 @@ package object config { private[spark] val DRIVER_CORES = ConfigBuilder("spark.driver.cores") .intConf - .optional + .withDefault(1) private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.driver.memoryOverhead") .bytesConf(ByteUnit.MiB) @@ -191,6 +198,10 @@ package object config { /* Executor configuration. 
*/ + private[spark] val EXECUTOR_CORES = ConfigBuilder("spark.executor.cores") + .intConf + .withDefault(1) + private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.executor.memoryOverhead") .bytesConf(ByteUnit.MiB) .optional @@ -245,5 +256,4 @@ package object config { .stringConf .toSequence .optional - } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index 9fc727904b..56dc0004d0 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -48,11 +48,10 @@ private[spark] class YarnClientSchedulerBackend( val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ("--arg", hostport) - argsArrayBuf ++= getExtraClientArguments logDebug("ClientArguments called with: " + argsArrayBuf.mkString(" ")) - val args = new ClientArguments(argsArrayBuf.toArray, conf) - totalExpectedExecutors = args.numExecutors + val args = new ClientArguments(argsArrayBuf.toArray) + totalExpectedExecutors = YarnSparkHadoopUtil.getInitialTargetExecutorNumber(conf) client = new Client(args, conf) bindToYarn(client.submitApplication(), None) @@ -72,43 +71,6 @@ private[spark] class YarnClientSchedulerBackend( monitorThread.start() } - /** - * Return any extra command line arguments to be passed to Client provided in the form of - * environment variables or Spark properties. - */ - private def getExtraClientArguments: Seq[String] = { - val extraArgs = new ArrayBuffer[String] - // List of (target Client argument, environment variable, Spark property) - val optionTuples = - List( - ("--executor-memory", "SPARK_WORKER_MEMORY", "spark.executor.memory"), - ("--executor-memory", "SPARK_EXECUTOR_MEMORY", "spark.executor.memory"), - ("--executor-cores", "SPARK_WORKER_CORES", "spark.executor.cores"), - ("--executor-cores", "SPARK_EXECUTOR_CORES", "spark.executor.cores"), - ("--queue", "SPARK_YARN_QUEUE", "spark.yarn.queue"), - ("--py-files", null, "spark.submit.pyFiles") - ) - // Warn against the following deprecated environment variables: env var -> suggestion - val deprecatedEnvVars = Map( - "SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit", - "SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit") - optionTuples.foreach { case (optionName, envVar, sparkProp) => - if (sc.getConf.contains(sparkProp)) { - extraArgs += (optionName, sc.getConf.get(sparkProp)) - } else if (envVar != null && System.getenv(envVar) != null) { - extraArgs += (optionName, System.getenv(envVar)) - if (deprecatedEnvVars.contains(envVar)) { - logWarning(s"NOTE: $envVar is deprecated. Use ${deprecatedEnvVars(envVar)} instead.") - } - } - } - // The app name is a special case because "spark.app.name" is required of all applications. - // As a result, the corresponding "SPARK_YARN_APP_NAME" is already handled preemptively in - // SparkSubmitArguments if "spark.app.name" is not explicitly set by the user. (SPARK-5222) - sc.getConf.getOption("spark.app.name").foreach(v => extraArgs += ("--name", v)) - extraArgs - } - /** * Report the state of the application until it is running. * If the application has finished, failed or been killed in the process, throw an exception. 
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 64723c361c..2eaafa072a 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -118,8 +118,9 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll val sparkConf = new SparkConf() .set(SPARK_JARS, Seq(SPARK)) .set(USER_CLASS_PATH_FIRST, true) + .set("spark.yarn.dist.jars", ADDED) val env = new MutableHashMap[String, String]() - val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf) + val args = new ClientArguments(Array("--jar", USER)) populateClasspath(args, conf, sparkConf, env) @@ -138,9 +139,11 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll } test("Jar path propagation through SparkConf") { - val sparkConf = new SparkConf().set(SPARK_JARS, Seq(SPARK)) - val client = createClient(sparkConf, - args = Array("--jar", USER, "--addJars", ADDED)) + val conf = new Configuration() + val sparkConf = new SparkConf() + .set(SPARK_JARS, Seq(SPARK)) + .set("spark.yarn.dist.jars", ADDED) + val client = createClient(sparkConf, args = Array("--jar", USER)) val tempDir = Utils.createTempDir() try { @@ -192,9 +195,9 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll val sparkConf = new SparkConf() .set(APPLICATION_TAGS.key, ",tag1, dup,tag2 , ,multi word , dup") .set(MAX_APP_ATTEMPTS, 42) - val args = new ClientArguments(Array( - "--name", "foo-test-app", - "--queue", "staging-queue"), sparkConf) + .set("spark.app.name", "foo-test-app") + .set(QUEUE_NAME, "staging-queue") + val args = new ClientArguments(Array()) val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) val getNewApplicationResponse = Records.newRecord(classOf[GetNewApplicationResponse]) @@ -346,7 +349,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll sparkConf: SparkConf, conf: Configuration = new Configuration(), args: Array[String] = Array()): Client = { - val clientArgs = new ClientArguments(args, sparkConf) + val clientArgs = new ClientArguments(args) val client = spy(new Client(clientArgs, conf, sparkConf)) doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]), any(classOf[Path]), anyShort()) diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala index 0587444a33..a641a6e73e 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala @@ -90,12 +90,13 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter def createAllocator(maxExecutors: Int = 5): YarnAllocator = { val args = Array( - "--executor-cores", "5", - "--executor-memory", "2048", "--jar", "somejar.jar", "--class", "SomeClass") val sparkConfClone = sparkConf.clone() - sparkConfClone.set("spark.executor.instances", maxExecutors.toString) + sparkConfClone + .set("spark.executor.instances", maxExecutors.toString) + .set("spark.executor.cores", "5") + .set("spark.executor.memory", "2048") new YarnAllocator( "not used", mock(classOf[RpcEndpointRef]), @@ -103,7 +104,6 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter sparkConfClone, rmClient, appAttemptId, - new 
ApplicationMasterArguments(args), new SecurityManager(sparkConf)) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 26520529ec..b2b4d84f53 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -85,6 +85,35 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testBasicYarnApp(false) } + test("run Spark in yarn-client mode with different configurations") { + testBasicYarnApp(true, + Map( + "spark.driver.memory" -> "512m", + "spark.executor.cores" -> "1", + "spark.executor.memory" -> "512m", + "spark.executor.instances" -> "2" + )) + } + + test("run Spark in yarn-cluster mode with different configurations") { + testBasicYarnApp(true, + Map( + "spark.driver.memory" -> "512m", + "spark.driver.cores" -> "1", + "spark.executor.cores" -> "1", + "spark.executor.memory" -> "512m", + "spark.executor.instances" -> "2" + )) + } + + test("run Spark in yarn-client mode with additional jar") { + testWithAddJar(true) + } + + test("run Spark in yarn-cluster mode with additional jar") { + testWithAddJar(false) + } + test("run Spark in yarn-cluster mode unsuccessfully") { // Don't provide arguments so the driver will fail. val finalState = runSpark(false, mainClassName(YarnClusterDriver.getClass)) @@ -139,13 +168,26 @@ class YarnClusterSuite extends BaseYarnClusterSuite { } } - private def testBasicYarnApp(clientMode: Boolean): Unit = { + private def testBasicYarnApp(clientMode: Boolean, conf: Map[String, String] = Map()): Unit = { val result = File.createTempFile("result", null, tempDir) val finalState = runSpark(clientMode, mainClassName(YarnClusterDriver.getClass), - appArgs = Seq(result.getAbsolutePath())) + appArgs = Seq(result.getAbsolutePath()), + extraConf = conf) checkResult(finalState, result) } + private def testWithAddJar(clientMode: Boolean): Unit = { + val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir) + val driverResult = File.createTempFile("driver", null, tempDir) + val executorResult = File.createTempFile("executor", null, tempDir) + val finalState = runSpark(clientMode, mainClassName(YarnClasspathTest.getClass), + appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()), + extraClassPath = Seq(originalJar.getPath()), + extraJars = Seq("local:" + originalJar.getPath())) + checkResult(finalState, driverResult, "ORIGINAL") + checkResult(finalState, executorResult, "ORIGINAL") + } + private def testPySpark(clientMode: Boolean): Unit = { val primaryPyFile = new File(tempDir, "test.py") Files.write(TEST_PYFILE, primaryPyFile, StandardCharsets.UTF_8) -- cgit v1.2.3 From 381358fbe9afbe205299cbbea4c43148e2e69468 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 1 Apr 2016 12:53:39 -0700 Subject: [SPARK-14305][ML][PYSPARK] PySpark ml.clustering BisectingKMeans support export/import ## What changes were proposed in this pull request? PySpark ml.clustering BisectingKMeans support export/import ## How was this patch tested? doc test. cc jkbradley Author: Yanbo Liang Closes #12112 from yanboliang/spark-14305. 
--- python/pyspark/ml/clustering.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index e22d5c8ea4..f071c597c8 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -171,7 +171,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol return self.getOrDefault(self.initSteps) -class BisectingKMeansModel(JavaModel): +class BisectingKMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -195,7 +195,8 @@ class BisectingKMeansModel(JavaModel): @inherit_doc -class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed): +class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasSeed, + JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -225,6 +226,18 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte True >>> rows[2].prediction == rows[3].prediction True + >>> bkm_path = temp_path + "/bkm" + >>> bkm.save(bkm_path) + >>> bkm2 = BisectingKMeans.load(bkm_path) + >>> bkm2.getK() + 2 + >>> model_path = temp_path + "/bkm_model" + >>> model.save(model_path) + >>> model2 = BisectingKMeansModel.load(model_path) + >>> model.clusterCenters()[0] == model2.clusterCenters()[0] + array([ True, True], dtype=bool) + >>> model.clusterCenters()[1] == model2.clusterCenters()[1] + array([ True, True], dtype=bool) .. versionadded:: 2.0.0 """ -- cgit v1.2.3 From df68beb85de59bb6d35b2a8a3b85dbc447798bf5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Apr 2016 13:00:55 -0700 Subject: [SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-13995 We infer relative `IsNotNull` constraints from logical plan's expressions in `constructIsNotNullConstraints` now. However, we don't consider the case of (nested) `Cast`. For example: val tr = LocalRelation('a.int, 'b.long) val plan = tr.where('a.attr === 'b.attr).analyze Then, the plan's constraints will have `IsNotNull(Cast(resolveColumn(tr, "a"), LongType))`, instead of `IsNotNull(resolveColumn(tr, "a"))`. This PR fixes it. Besides, as `IsNotNull` constraints are most useful for `Attribute`, we should do recursing through any `Expression` that is null intolerant and construct `IsNotNull` constraints for all `Attribute`s under these Expressions. For example, consider the following constraints: val df = Seq((1,2,3)).toDF("a", "b", "c") df.where("a + b = c").queryExecution.analyzed.constraints The inferred isnotnull constraints should be isnotnull(a), isnotnull(b), isnotnull(c), instead of isnotnull(a + c) and isnotnull(c). ## How was this patch tested? Test is added into `ConstraintPropagationSuite`. Author: Liang-Chi Hsieh Closes #11809 from viirya/constraint-cast. 
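As a quick sanity check (sketch only, not part of the patch), the example from the description can be replayed in a Spark shell to print the inferred constraints; the `sqlContext` handle and the exact printed form are assumptions about the surrounding session.

```scala
// Sketch only -- a rough way to observe the inferred constraints.
// Assumes a Spark shell with a SQLContext bound to `sqlContext`.
import sqlContext.implicits._

val df = Seq((1, 2, 3)).toDF("a", "b", "c")

// With the NullIntolerant recursion, the analyzer walks through `a + b = c`
// and emits one IsNotNull constraint per attribute instead of constraints on
// the whole compound expression.
df.where("a + b = c").queryExecution.analyzed.constraints.foreach(println)
// expected to include: isnotnull(a), isnotnull(b), isnotnull(c)
```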
--- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 25 ++++--- .../catalyst/expressions/namedExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/package.scala | 7 ++ .../sql/catalyst/expressions/predicates.scala | 17 +++-- .../spark/sql/catalyst/plans/QueryPlan.scala | 33 ++++----- .../plans/ConstraintPropagationSuite.scala | 85 +++++++++++++++++++++- 7 files changed, 134 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a965cc8d53..d842ffdc66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -112,7 +112,7 @@ object Cast { } /** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant { override def toString: String = s"cast($child as ${dataType.simpleString})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1e9c971800..b388091538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnaryMinus(child: Expression) extends UnaryExpression + with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def sql: String = s"(-${child.sql})" } -case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnaryPositive(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def prettyName: String = "positive" override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects @ExpressionDescription( usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", extended = "> SELECT _FUNC_('-1');\n1") -case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Abs(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) @@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } -case class Add(left: Expression, right: Expression) extends BinaryArithmetic { +case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } } -case class Subtract(left: Expression, 
right: Expression) extends BinaryArithmetic { +case class Subtract(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } } -case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { +case class Multiply(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { +case class Divide(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -269,7 +275,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -457,7 +464,7 @@ case class MinOf(left: Expression, right: Expression) override def symbol: String = "min" } -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a5b5758167..262582ca5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -97,7 +97,7 @@ trait NamedExpression extends Expression { } } -abstract class Attribute extends LeafExpression with NamedExpression { +abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant { override def references: AttributeSet = AttributeSet(this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index f1fa13daa7..23baa6f783 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -92,4 +92,11 @@ package object expressions { StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) } } + + /** + * When an expression inherits this, meaning the expression is null intolerant (i.e. any null + * input will result in null output). We will use this information during constructing IsNotNull + * constraints. 
+ */ + trait NullIntolerant } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e23ad5596b..4eb33258ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,7 +90,7 @@ trait PredicateHelper { case class Not(child: Expression) - extends UnaryExpression with Predicate with ImplicitCastInputTypes { + extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { override def toString: String = s"NOT $child" @@ -402,7 +402,8 @@ private[sql] object Equality { } -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { +case class EqualTo(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = AnyDataType @@ -467,7 +468,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } -case class LessThan(left: Expression, right: Expression) extends BinaryComparison { +case class LessThan(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -479,7 +481,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso } -case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { +case class LessThanOrEqual(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -491,7 +494,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo } -case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { +case class GreaterThan(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -503,7 +507,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar } -case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { +case class GreaterThanOrEqual(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d31164fe94..22a4461e66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * returns a constraint of the form `isNotNull(a)` */ private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { - var isNotNullConstraints = Set.empty[Expression] - - // First, we propagate constraints if the condition consists of equality and ranges. 
For all - // other cases, we return an empty set of constraints - constraints.foreach { - case EqualTo(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case GreaterThan(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case GreaterThanOrEqual(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case LessThan(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case LessThanOrEqual(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case Not(EqualTo(l, r)) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case _ => // No inference - } + // First, we propagate constraints from the null intolerant expressions. + var isNotNullConstraints: Set[Expression] = + constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_)) // Second, we infer additional constraints from non-nullable attributes that are part of the // operator's output @@ -72,6 +56,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT isNotNullConstraints -- constraints } + /** + * Recursively explores the expressions which are null intolerant and returns all attributes + * in these expressions. + */ + private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match { + case a: Attribute => Seq(a) + case _: NullIntolerant | IsNotNull(_: NullIntolerant) => + expr.children.flatMap(scanNullIntolerantExpr) + case _ => Seq.empty[Attribute] + } + /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index e5063599a3..5cbb889f8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType} class ConstraintPropagationSuite extends SparkFunSuite { @@ -219,6 +219,89 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "b"))))) } + test("infer constraints on cast") { + val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + verifyConstraints( + tr.where('a.attr === 'b.attr && + 'c.attr + 100 > 'd.attr && + IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints, + ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"), + Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e")), + IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))))) + } + + test("infer isnotnull constraints from compound expressions") { + val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + verifyConstraints( + tr.where('a.attr + 'b.attr === 'c.attr && + 
IsNotNull( + Cast( + Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") === + Cast(resolveColumn(tr, "c"), LongType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "e")), + IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))))) + + verifyConstraints( + tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) === + Cast(resolveColumn(tr, "c"), LongType), + Cast(resolveColumn(tr, "d"), DoubleType) / + Cast(Cast(10, LongType), DoubleType) === + Cast(resolveColumn(tr, "e"), DoubleType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + verifyConstraints( + tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >= + Cast(resolveColumn(tr, "c"), LongType), + Cast(resolveColumn(tr, "d"), DoubleType) / + Cast(Cast(10, LongType), DoubleType) < + Cast(resolveColumn(tr, "e"), DoubleType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + verifyConstraints( + tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints, + ExpressionSet(Seq( + (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) - + (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) > + Cast(resolveColumn(tr, "e") * 1000, LongType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null. + verifyConstraints( + tr.where('a.attr === 'c.attr && + IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints, + ExpressionSet(Seq( + resolveColumn(tr, "a") === resolveColumn(tr, "c"), + IsNotNull(IsNotNull(resolveColumn(tr, "b"))), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c"))))) + } + test("infer IsNotNull constraints from non-nullable attributes") { val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(), AttributeReference("c", StringType, nullable = false)()) -- cgit v1.2.3 From a884daad805a701494e87393dc307937472a985d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Apr 2016 13:03:27 -0700 Subject: [SPARK-14191][SQL] Remove invalid Expand operator constraints `Expand` operator now uses its child plan's constraints as its valid constraints (i.e., the base of constraints). This is not correct because `Expand` will set its group by attributes to null values. So the nullability of these attributes should be true. 
E.g., for an `Expand` operator like: val input = LocalRelation('a.int, 'b.int, 'c.int).where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2) Expand( Seq( Seq('c, Literal.create(null, StringType), 1), Seq('c, 'a, 2)), Seq('c, 'a, 'gid.int), Project(Seq('a, 'c), input)) The `Project` operator has the constraints `IsNotNull('a)`, `IsNotNull('b)` and `IsNotNull('c)`. But the `Expand` should not have `IsNotNull('a)` in its constraints. This PR is the first step for this issue and remove invalid constraints of `Expand` operator. A test is added to `ConstraintPropagationSuite`. Author: Liang-Chi Hsieh Author: Michael Armbrust Closes #11995 from viirya/fix-expand-constraints. --- .../catalyst/plans/logical/basicOperators.scala | 5 +++- .../plans/ConstraintPropagationSuite.scala | 27 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09c200fa83..a18efc90ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -519,7 +519,6 @@ case class Expand( projections: Seq[Seq[Expression]], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) @@ -527,6 +526,10 @@ case class Expand( val sizeInBytes = super.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } + + // This operator can reuse attributes (for example making them null when doing a roll up) so + // the contraints of the child may no longer be valid. + override protected def validConstraints: Set[Expression] = Set.empty[Expression] } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 5cbb889f8e..49c1353efb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -88,6 +88,33 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))) } + test("propagating constraints in expand") { + val tr = LocalRelation('a.int, 'b.int, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation + // by creating notNullRelation. 
+ val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2) + verifyConstraints(notNullRelation.analyze.constraints, + ExpressionSet(Seq(resolveColumn(notNullRelation.analyze, "c") > 10, + IsNotNull(resolveColumn(notNullRelation.analyze, "c")), + resolveColumn(notNullRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(notNullRelation.analyze, "a")), + resolveColumn(notNullRelation.analyze, "b") > 2, + IsNotNull(resolveColumn(notNullRelation.analyze, "b"))))) + + val expand = Expand( + Seq( + Seq('c, Literal.create(null, StringType), 1), + Seq('c, 'a, 2)), + Seq('c, 'a, 'gid.int), + Project(Seq('a, 'c), + notNullRelation)) + verifyConstraints(expand.analyze.constraints, + ExpressionSet(Seq.empty[Expression])) + } + test("propagating constraints in aliases") { val tr = LocalRelation('a.int, 'b.string, 'c.int) -- cgit v1.2.3 From 1e886159849e3918445d3fdc3c4cef86c6c1a236 Mon Sep 17 00:00:00 2001 From: Tejas Patil Date: Fri, 1 Apr 2016 13:13:16 -0700 Subject: [SPARK-14070][SQL] Use ORC data source for SQL queries on ORC tables ## What changes were proposed in this pull request? This patch enables use of OrcRelation for SQL queries which read data from Hive tables. Changes in this patch: - Added a new rule `OrcConversions` which would alter the plan to use `OrcRelation`. In this diff, the conversion is done only for reads. - Added a new config `spark.sql.hive.convertMetastoreOrc` to control the conversion BEFORE ``` scala> hqlContext.sql("SELECT * FROM orc_table").explain(true) == Parsed Logical Plan == 'Project [unresolvedalias(*, None)] +- 'UnresolvedRelation `orc_table`, None == Analyzed Logical Plan == key: string, value: string Project [key#171,value#172] +- MetastoreRelation default, orc_table, None == Optimized Logical Plan == MetastoreRelation default, orc_table, None == Physical Plan == HiveTableScan [key#171,value#172], MetastoreRelation default, orc_table, None ``` AFTER ``` scala> hqlContext.sql("SELECT * FROM orc_table").explain(true) == Parsed Logical Plan == 'Project [unresolvedalias(*, None)] +- 'UnresolvedRelation `orc_table`, None == Analyzed Logical Plan == key: string, value: string Project [key#76,value#77] +- SubqueryAlias orc_table +- Relation[key#76,value#77] ORC part: struct<>, data: struct == Optimized Logical Plan == Relation[key#76,value#77] ORC part: struct<>, data: struct == Physical Plan == WholeStageCodegen : +- Scan ORC part: struct<>, data: struct[key#76,value#77] InputPaths: file:/user/hive/warehouse/orc_table ``` ## How was this patch tested? - Added a new unit test. Ran existing unit tests - Ran with production like data ## Performance gains Ran on a production table in Facebook (note that the data was in DWRF file format which is similar to ORC) Best case : when there was no matching rows for the predicate in the query (everything is filtered out) ``` CPU time Wall time Total wall time across all tasks ================================================================ Without the change 541_515 sec 25.0 mins 165.8 hours With change 407 sec 1.5 mins 15 mins ``` Average case: A subset of rows in the data match the query predicate ``` CPU time Wall time Total wall time across all tasks ================================================================ Without the change 624_630 sec 31.0 mins 199.0 h With change 14_769 sec 5.3 mins 7.7 h ``` Author: Tejas Patil Closes #11891 from tejasapatil/orc_ppd. 
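A minimal sketch of exercising the new code path, assuming a `HiveContext` named `hqlContext` and the ORC-backed Hive table `orc_table` used in the plans above (illustrative only, not part of the patch):

```scala
// Sketch only: toggle the new flag and compare the plans.
// Assumes `hqlContext` (a HiveContext) and the ORC table `orc_table` from above.
hqlContext.setConf("spark.sql.hive.convertMetastoreOrc", "true") // default after this patch
hqlContext.sql("SELECT * FROM orc_table").explain(true)
// -> planned through the ORC data source relation, as in the AFTER plan above

hqlContext.setConf("spark.sql.hive.convertMetastoreOrc", "false")
hqlContext.sql("SELECT * FROM orc_table").explain(true)
// -> falls back to MetastoreRelation / HiveTableScan, as in the BEFORE plan above
```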
--- .../hive/execution/HiveCompatibilitySuite.scala | 8 +- .../org/apache/spark/sql/hive/HiveContext.scala | 12 ++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 234 ++++++++++++++------- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 1 + .../apache/spark/sql/hive/HiveSessionState.scala | 1 + .../apache/spark/sql/hive/orc/OrcQuerySuite.scala | 39 ++++ 6 files changed, 220 insertions(+), 75 deletions(-) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 695b5ef733..d8695bc5db 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -20,9 +20,10 @@ package org.apache.spark.sql.hive.execution import java.io.File import java.util.{Locale, TimeZone} -import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.internal.SQLConf @@ -38,6 +39,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning + private val originalConvertMetastoreOrc = TestHive.convertMetastoreOrc def testCases: Seq[(String, File)] = { hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) @@ -56,6 +58,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) // Use Hive hash expression instead of the native one TestHive.sessionState.functionRegistry.unregisterFunction("hash") + // Ensures that the plans generation use metastore relation and not OrcRelation + // Was done because SqlBuilder does not work with plans having logical relation + TestHive.setConf(HiveContext.CONVERT_METASTORE_ORC, false) RuleExecutor.resetTime() } @@ -66,6 +71,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + TestHive.setConf(HiveContext.CONVERT_METASTORE_ORC, originalConvertMetastoreOrc) TestHive.sessionState.functionRegistry.restore() // For debugging dump some statistics about how much time was spent in various optimizer rules. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index c0b6d16d3c..073b954a5f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -154,6 +154,13 @@ class HiveContext private[hive]( protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) + /** + * When true, enables an experimental feature where metastore tables that use the Orc SerDe + * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive + * SerDe. 
+ */ + protected[sql] def convertMetastoreOrc: Boolean = getConf(CONVERT_METASTORE_ORC) + /** * When true, a table created by a Hive CTAS statement (no USING clause) will be * converted to a data source table, using the data source set by spark.sql.sources.default. @@ -442,6 +449,11 @@ private[hive] object HiveContext extends Logging { "different Parquet data files. This configuration is only effective " + "when \"spark.sql.hive.convertMetastoreParquet\" is true.") + val CONVERT_METASTORE_ORC = booleanConf("spark.sql.hive.convertMetastoreOrc", + defaultValue = Some(true), + doc = "When set to false, Spark SQL will use the Hive SerDe for ORC tables instead of " + + "the built in support.") + val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", defaultValue = Some(false), doc = "When true, a table created by a Hive CTAS statement (no USING clause) will be " + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 2cdc931c3f..14f331961e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -40,12 +40,13 @@ import org.apache.spark.sql.catalyst.parser.DataTypeParser import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.execution.{datasources, FileRelation} +import org.apache.spark.sql.execution.FileRelation import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource, ParquetRelation} +import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetDefaultSource, ParquetRelation} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand -import org.apache.spark.sql.sources.{HadoopFsRelation, HDFSFileCatalog} +import org.apache.spark.sql.hive.orc.{DefaultSource => OrcDefaultSource} +import org.apache.spark.sql.sources.{FileFormat, HadoopFsRelation, HDFSFileCatalog} import org.apache.spark.sql.types._ private[hive] case class HiveSerDe( @@ -451,58 +452,72 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte } } - private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = { - val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) - val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging - - val parquetOptions = Map( - ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, - ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( - metastoreRelation.tableName, - Some(metastoreRelation.databaseName) - ).unquotedString - ) - val tableIdentifier = - QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) - - def getCached( - tableIdentifier: QualifiedTableName, - pathsInMetastore: Seq[String], - schemaInMetastore: StructType, - partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { - cachedDataSourceTables.getIfPresent(tableIdentifier) match { - case null => None // Cache miss - case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => - // If we have the same paths, same schema, and same partition spec, - // we will use the cached Parquet Relation. 
- val useCached = - parquetRelation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet && - logical.schema.sameType(metastoreSchema) && - parquetRelation.partitionSpec == partitionSpecInMetastore.getOrElse { - PartitionSpec(StructType(Nil), Array.empty[datasources.PartitionDirectory]) + private def getCached( + tableIdentifier: QualifiedTableName, + metastoreRelation: MetastoreRelation, + schemaInMetastore: StructType, + expectedFileFormat: Class[_ <: FileFormat], + expectedBucketSpec: Option[BucketSpec], + partitionSpecInMetastore: Option[PartitionSpec]): Option[LogicalRelation] = { + + cachedDataSourceTables.getIfPresent(tableIdentifier) match { + case null => None // Cache miss + case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) => + val pathsInMetastore = metastoreRelation.table.storage.locationUri.toSeq + val cachedRelationFileFormatClass = relation.fileFormat.getClass + + expectedFileFormat match { + case `cachedRelationFileFormatClass` => + // If we have the same paths, same schema, and same partition spec, + // we will use the cached relation. + val useCached = + relation.location.paths.map(_.toString).toSet == pathsInMetastore.toSet && + logical.schema.sameType(schemaInMetastore) && + relation.bucketSpec == expectedBucketSpec && + relation.partitionSpec == partitionSpecInMetastore.getOrElse { + PartitionSpec(StructType(Nil), Array.empty[PartitionDirectory]) + } + + if (useCached) { + Some(logical) + } else { + // If the cached relation is not updated, we invalidate it right away. + cachedDataSourceTables.invalidate(tableIdentifier) + None } - - if (useCached) { - Some(logical) - } else { - // If the cached relation is not updated, we invalidate it right away. + case _ => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} " + + s"should be stored as $expectedFileFormat. However, we are getting " + + s"a ${relation.fileFormat} from the metastore cache. This cached " + + s"entry will be invalidated.") cachedDataSourceTables.invalidate(tableIdentifier) None - } - case other => - logWarning( - s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + - s"as Parquet. However, we are getting a $other from the metastore cache. " + - s"This cached entry will be invalidated.") - cachedDataSourceTables.invalidate(tableIdentifier) - None - } + } + case other => + logWarning( + s"${metastoreRelation.databaseName}.${metastoreRelation.tableName} should be stored " + + s"as $expectedFileFormat. However, we are getting a $other from the metastore cache. " + + s"This cached entry will be invalidated.") + cachedDataSourceTables.invalidate(tableIdentifier) + None } + } + + private def convertToLogicalRelation(metastoreRelation: MetastoreRelation, + options: Map[String, String], + defaultSource: FileFormat, + fileFormatClass: Class[_ <: FileFormat], + fileType: String): LogicalRelation = { + val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) + val tableIdentifier = + QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) + val bucketSpec = None // We don't support hive bucketed tables, only ones we write out. 
val result = if (metastoreRelation.hiveQlTable.isPartitioned) { val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys) val partitionColumnDataTypes = partitionSchema.map(_.dataType) - // We're converting the entire table into ParquetRelation, so predicates to Hive metastore + // We're converting the entire table into HadoopFsRelation, so predicates to Hive metastore // are empty. val partitions = metastoreRelation.getHiveQlPartitions().map { p => val location = p.getLocation @@ -515,54 +530,65 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val cached = getCached( tableIdentifier, - metastoreRelation.table.storage.locationUri.toSeq, + metastoreRelation, metastoreSchema, + fileFormatClass, + bucketSpec, Some(partitionSpec)) - val parquetRelation = cached.getOrElse { + val hadoopFsRelation = cached.getOrElse { val paths = new Path(metastoreRelation.table.storage.locationUri.get) :: Nil val fileCatalog = new MetaStoreFileCatalog(hive, paths, partitionSpec) - val format = new DefaultSource() - val inferredSchema = format.inferSchema(hive, parquetOptions, fileCatalog.allFiles()) - val mergedSchema = inferredSchema.map { inferred => - ParquetRelation.mergeMetastoreParquetSchema(metastoreSchema, inferred) - }.getOrElse(metastoreSchema) + val inferredSchema = if (fileType.equals("parquet")) { + val inferredSchema = defaultSource.inferSchema(hive, options, fileCatalog.allFiles()) + inferredSchema.map { inferred => + ParquetRelation.mergeMetastoreParquetSchema(metastoreSchema, inferred) + }.getOrElse(metastoreSchema) + } else { + defaultSource.inferSchema(hive, options, fileCatalog.allFiles()).get + } val relation = HadoopFsRelation( sqlContext = hive, location = fileCatalog, partitionSchema = partitionSchema, - dataSchema = mergedSchema, - bucketSpec = None, // We don't support hive bucketed tables, only ones we write out. - fileFormat = new DefaultSource(), - options = parquetOptions) + dataSchema = inferredSchema, + bucketSpec = bucketSpec, + fileFormat = defaultSource, + options = options) val created = LogicalRelation(relation) cachedDataSourceTables.put(tableIdentifier, created) created } - parquetRelation + hadoopFsRelation } else { val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString) - val cached = getCached(tableIdentifier, paths, metastoreSchema, None) - val parquetRelation = cached.getOrElse { + val cached = getCached(tableIdentifier, + metastoreRelation, + metastoreSchema, + fileFormatClass, + bucketSpec, + None) + val logicalRelation = cached.getOrElse { val created = LogicalRelation( DataSource( sqlContext = hive, paths = paths, userSpecifiedSchema = Some(metastoreRelation.schema), - options = parquetOptions, - className = "parquet").resolveRelation()) + bucketSpec = bucketSpec, + options = options, + className = fileType).resolveRelation()) cachedDataSourceTables.put(tableIdentifier, created) created } - parquetRelation + logicalRelation } result.copy(expectedOutputAttributes = Some(metastoreRelation.output)) } @@ -572,6 +598,27 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte * data source relations for better performance. 
*/ object ParquetConversions extends Rule[LogicalPlan] { + private def shouldConvertMetastoreParquet(relation: MetastoreRelation): Boolean = { + relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") && + hive.convertMetastoreParquet + } + + private def convertToParquetRelation(relation: MetastoreRelation): LogicalRelation = { + val defaultSource = new ParquetDefaultSource() + val fileFormatClass = classOf[ParquetDefaultSource] + + val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging + val options = Map( + ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, + ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( + relation.tableName, + Some(relation.databaseName) + ).unquotedString + ) + + convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "parquet") + } + override def apply(plan: LogicalPlan): LogicalPlan = { if (!plan.resolved || plan.analyzed) { return plan @@ -581,28 +628,67 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte // Write path case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) // Inserting into partitioned table is not supported in Parquet data source (yet). - if !r.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - r.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => - val parquetRelation = convertToParquetRelation(r) - InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) // Write path case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) // Inserting into partitioned table is not supported in Parquet data source (yet). - if !r.hiveQlTable.isPartitioned && hive.convertMetastoreParquet && - r.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => - val parquetRelation = convertToParquetRelation(r) - InsertIntoTable(parquetRelation, partition, child, overwrite, ifNotExists) + if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) // Read path - case relation: MetastoreRelation if hive.convertMetastoreParquet && - relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") => + case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) => val parquetRelation = convertToParquetRelation(relation) SubqueryAlias(relation.alias.getOrElse(relation.tableName), parquetRelation) } } } + /** + * When scanning Metastore ORC tables, convert them to ORC data source relations + * for better performance. 
+ */ + object OrcConversions extends Rule[LogicalPlan] { + private def shouldConvertMetastoreOrc(relation: MetastoreRelation): Boolean = { + relation.tableDesc.getSerdeClassName.toLowerCase.contains("orc") && + hive.convertMetastoreOrc + } + + private def convertToOrcRelation(relation: MetastoreRelation): LogicalRelation = { + val defaultSource = new OrcDefaultSource() + val fileFormatClass = classOf[OrcDefaultSource] + val options = Map[String, String]() + + convertToLogicalRelation(relation, options, defaultSource, fileFormatClass, "orc") + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!plan.resolved || plan.analyzed) { + return plan + } + + plan transformUp { + // Write path + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + // Inserting into partitioned table is not supported in Orc data source (yet). + if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + + // Write path + case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + // Inserting into partitioned table is not supported in Orc data source (yet). + if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + + // Read path + case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) => + val orcRelation = convertToOrcRelation(relation) + SubqueryAlias(relation.alias.getOrElse(relation.tableName), orcRelation) + } + } + } + /** * Creates any tables required for query execution. * For example, because of a CREATE TABLE X AS statement. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 1cd783e63a..dfbf22cc47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -74,6 +74,7 @@ class HiveSessionCatalog( private val metastoreCatalog = new HiveMetastoreCatalog(client, context) val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions + val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 11ef0fd1bb..2bdb428e9d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -57,6 +57,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = catalog.ParquetConversions :: + catalog.OrcConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: python.ExtractPythonUDFs :: diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 92f424bac7..5ef8194f28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -26,6 +26,8 
@@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ import org.apache.spark.sql.internal.SQLConf @@ -400,4 +402,41 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } } + + test("SPARK-14070 Use ORC data source for SQL queries on ORC tables") { + withTempPath { dir => + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true", + HiveContext.CONVERT_METASTORE_ORC.key -> "true") { + val path = dir.getCanonicalPath + + withTable("dummy_orc") { + withTempTable("single") { + sqlContext.sql( + s"""CREATE TABLE dummy_orc(key INT, value STRING) + |STORED AS ORC + |LOCATION '$path' + """.stripMargin) + + val singleRowDF = Seq((0, "foo")).toDF("key", "value").coalesce(1) + singleRowDF.registerTempTable("single") + + sqlContext.sql( + s"""INSERT INTO TABLE dummy_orc + |SELECT key, value FROM single + """.stripMargin) + + val df = sqlContext.sql("SELECT * FROM dummy_orc WHERE key=0") + checkAnswer(df, singleRowDF) + + val queryExecution = df.queryExecution + queryExecution.analyzed.collectFirst { + case _: LogicalRelation => () + }.getOrElse { + fail(s"Expecting the query plan to have LogicalRelation, but got:\n$queryExecution") + } + } + } + } + } + } } -- cgit v1.2.3 From 1b829ce13990b40fd8d7c9efcc2ae55c4dbc861c Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 1 Apr 2016 13:19:24 -0700 Subject: [SPARK-14160] Time Windowing functions for Datasets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR adds the function `window` as a column expression. `window` can be used to bucket rows into time windows given a time column. With this expression, performing time series analysis on batch data, as well as streaming data should become much more simpler. ### Usage Assume the following schema: `sensor_id, measurement, timestamp` To average 5 minute data every 1 minute (window length of 5 minutes, slide duration of 1 minute), we will use: ```scala df.groupBy(window("timestamp", “5 minutes”, “1 minute”), "sensor_id") .agg(mean("measurement").as("avg_meas")) ``` This will generate windows such as: ``` 09:00:00-09:05:00 09:01:00-09:06:00 09:02:00-09:07:00 ... ``` Intervals will start at every `slideDuration` starting at the unix epoch (1970-01-01 00:00:00 UTC). To start intervals at a different point of time, e.g. 30 seconds after a minute, the `startTime` parameter can be used. ```scala df.groupBy(window("timestamp", “5 minutes”, “1 minute”, "30 second"), "sensor_id") .agg(mean("measurement").as("avg_meas")) ``` This will generate windows such as: ``` 09:00:30-09:05:30 09:01:30-09:06:30 09:02:30-09:07:30 ... ``` Support for Python will be made in a follow up PR after this. ## How was this patch tested? This patch has some basic unit tests for the `TimeWindow` expression testing that the parameters pass validation, and it also has some unit/integration tests testing the correctness of the windowing and usability in complex operations (multi-column grouping, multi-column projections, joins). Author: Burak Yavuz Author: Michael Armbrust Closes #12008 from brkyvz/df-time-window. 
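For completeness, the two usages from the description can be combined into one sketch. This is illustrative only; a DataFrame `df` with the `sensor_id, measurement, timestamp` schema described above is assumed.

```scala
// Sketch only. Assumes `df` has columns (sensor_id, measurement, timestamp)
// matching the schema in the description above.
import org.apache.spark.sql.functions._

// 12-minute windows sliding every 5 minutes, starting 2 minutes past the slide
// boundary -- the same parameters illustrated in TimeWindowing's scaladoc below.
val avgMeas = df
  .groupBy(window(col("timestamp"), "12 minutes", "5 minutes", "2 minutes"), col("sensor_id"))
  .agg(mean("measurement").as("avg_meas"))

// `window` is a struct column, so the bucket bounds stay addressable downstream
// as window.start and window.end.
avgMeas.select(col("window.start"), col("window.end"), col("sensor_id"), col("avg_meas"))
```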
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 90 ++++++++ .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/TimeWindow.scala | 133 +++++++++++ .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 56 +++++ .../sql/catalyst/expressions/TimeWindowSuite.scala | 76 +++++++ .../scala/org/apache/spark/sql/functions.scala | 137 ++++++++++++ .../spark/sql/DataFrameTimeWindowingSuite.scala | 242 +++++++++++++++++++++ 7 files changed, 735 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8dc0532b3f..d82ee3a205 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -102,6 +102,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + TimeWindowing :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -1591,3 +1592,92 @@ object ResolveUpCast extends Rule[LogicalPlan] { } } } + +/** + * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to + * figure out how many windows a time column can map to, we over-estimate the number of windows and + * filter out the rows where the time column is not inside the time window. + */ +object TimeWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val WINDOW_START = "start" + private final val WINDOW_END = "end" + + /** + * Generates the logical plan for generating window ranges on a timestamp column. Without + * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many + * window ranges a timestamp will map to given all possible combinations of a window duration, + * slide duration and start time (offset). Therefore, we express and over-estimate the number of + * windows there may be, and filter the valid windows. We use last Project operator to group + * the window columns into a struct so they can be accessed as `window.start` and `window.end`. + * + * The windows are calculated as below: + * maxNumOverlapping <- ceil(windowDuration / slideDuration) + * for (i <- 0 until maxNumOverlapping) + * windowId <- ceil((timestamp - startTime) / slideDuration) + * windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime + * windowEnd <- windowStart + windowDuration + * return windowStart, windowEnd + * + * This behaves as follows for the given parameters for the time: 12:05. The valid windows are + * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the + * Filter operator. 
+ * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m + * 11:55 - 12:07 + 11:52 - 12:04 x + * 12:00 - 12:12 + 11:57 - 12:09 + + * 12:05 - 12:17 + 12:02 - 12:14 + + * + * @param plan The logical plan + * @return the logical plan that will generate the time windows using the Expand operator, with + * the Filter operator for correctness and Project for usability. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val windowExpressions = + p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct. + + // Only support a single window expression for now + if (windowExpressions.size == 1 && + windowExpressions.head.timeColumn.resolved && + windowExpressions.head.checkInputDataTypes().isSuccess) { + val window = windowExpressions.head + val windowAttr = AttributeReference("window", window.dataType)() + + val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = Seq.tabulate(maxNumOverlapping + 1) { i => + val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / + window.slideDuration) + val windowStart = (windowId + i - maxNumOverlapping) * + window.slideDuration + window.startTime + val windowEnd = windowStart + window.windowDuration + + CreateNamedStruct( + Literal(WINDOW_START) :: windowStart :: + Literal(WINDOW_END) :: windowEnd :: Nil) + } + + val projections = windows.map(_ +: p.children.head.output) + + val filterExpr = + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) + + val expandedPlan = + Filter(filterExpr, + Expand(projections, windowAttr +: child.output, child)) + + val substitutedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + substitutedPlan.withNewChildren(expandedPlan :: Nil) + } else if (windowExpressions.size > 1) { + p.failAnalysis("Multiple time window expressions would result in a cartesian product " + + "of rows, therefore they are not currently not supported.") + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e9788b7e4d..ca8db3cbc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -297,6 +297,7 @@ object FunctionRegistry { expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), + expression[TimeWindow]("window"), // collection functions expression[ArrayContains]("array_contains"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala new file mode 100644 index 0000000000..8e13833486 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.lang.StringUtils + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +case class TimeWindow( + timeColumn: Expression, + windowDuration: Long, + slideDuration: Long, + startTime: Long) extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + override def child: Expression = timeColumn + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)) + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + /** + * Validate the inputs for the window duration, slide duration, and start time in addition to + * the input data type. + */ + override def checkInputDataTypes(): TypeCheckResult = { + val dataTypeCheck = super.checkInputDataTypes() + if (dataTypeCheck.isSuccess) { + if (windowDuration <= 0) { + return TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.") + } + if (slideDuration <= 0) { + return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.") + } + if (startTime < 0) { + return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.") + } + if (slideDuration > windowDuration) { + return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" + + s" to the windowDuration ($windowDuration).") + } + if (startTime >= slideDuration) { + return TypeCheckFailure(s"The start time ($startTime) must be less than the " + + s"slideDuration ($slideDuration).") + } + } + dataTypeCheck + } +} + +object TimeWindow { + /** + * Parses the interval string for a valid time duration. CalendarInterval expects interval + * strings to start with the string `interval`. For usability, we prepend `interval` to the string + * if the user omitted it. + * + * @param interval The interval string + * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond + * precision. 
+ */ + private def getIntervalInMicroSeconds(interval: String): Long = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "The window duration, slide duration and start time cannot be null or blank.") + } + val intervalString = if (interval.startsWith("interval")) { + interval + } else { + "interval " + interval + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"The provided interval ($interval) did not correspond to a valid interval string.") + } + if (cal.months > 0) { + throw new IllegalArgumentException( + s"Intervals greater than a month is not supported ($interval).") + } + cal.microseconds + } + + def apply( + timeColumn: Expression, + windowDuration: String, + slideDuration: String, + startTime: String): TimeWindow = { + TimeWindow(timeColumn, + getIntervalInMicroSeconds(windowDuration), + getIntervalInMicroSeconds(slideDuration), + getIntervalInMicroSeconds(startTime)) + } +} + +/** + * Expression used internally to convert the TimestampType to Long without losing + * precision, i.e. in microseconds. Used in time windowing. + */ +case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = LongType + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val eval = child.gen(ctx) + eval.code + + s"""boolean ${ev.isNull} = ${eval.isNull}; + |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + """.stripMargin + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index a90dfc5039..ad101d1c40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -272,6 +272,62 @@ class AnalysisErrorSuite extends AnalysisTest { testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))), "cannot resolve '`bad_column`'" :: Nil) + errorTest( + "slide duration greater than window in time window", + testRelation2.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second").as("window")), + s"The slide duration " :: " must be less than or equal to the windowDuration " :: Nil + ) + + errorTest( + "start time greater than slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), + "The start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "start time equal to slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), + "The start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "negative window duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "-1 second", "1 second", "0 second").as("window")), + "The window duration " :: " must be greater than 0." :: Nil + ) + + errorTest( + "zero window duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "0 second", "1 second", "0 second").as("window")), + "The window duration " :: " must be greater than 0." 
:: Nil + ) + + errorTest( + "negative slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "-1 second", "0 second").as("window")), + "The slide duration " :: " must be greater than 0." :: Nil + ) + + errorTest( + "zero slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "0 second", "0 second").as("window")), + "The slide duration" :: " must be greater than 0." :: Nil + ) + + errorTest( + "negative start time in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")), + "The start time" :: "must be greater than or equal to 0." :: Nil + ) + test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) // Since we manually construct the logical plan at here and Sum only accept diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala new file mode 100644 index 0000000000..71f969aee2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException + +class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("time window is unevaluable") { + intercept[UnsupportedOperationException] { + evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) + } + } + + private def checkErrorMessage(msg: String, value: String): Unit = { + val validDuration = "10 second" + val validTime = "5 second" + val e1 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration + } + val e2 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration + } + val e3 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), validDuration, validDuration, value).startTime + } + Seq(e1, e2, e3).foreach { e => + e.getMessage.contains(msg) + } + } + + test("blank intervals throw exception") { + for (blank <- Seq(null, " ", "\n", "\t")) { + checkErrorMessage( + "The window duration, slide duration and start time cannot be null or blank.", blank) + } + } + + test("invalid intervals throw exception") { + checkErrorMessage( + "did not correspond to a valid interval string.", "2 apples") + } + + test("intervals greater than a month throws exception") { + checkErrorMessage( + "Intervals greater than or equal to a month is not supported (1 month).", "1 month") + } + + test("interval strings work with and without 'interval' prefix and return microseconds") { + val validDuration = "10 second" + for ((text, seconds) <- Seq( + ("1 second", 1000000), // 1e6 + ("1 minute", 60000000), // 6e7 + ("2 hours", 7200000000L))) { // 72e9 + assert(TimeWindow(Literal(10L), text, validDuration, "0 seconds").windowDuration === seconds) + assert(TimeWindow(Literal(10L), "interval " + text, validDuration, "0 seconds").windowDuration + === seconds) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7ce15e3f35..74906050ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2550,6 +2550,143 @@ object functions { ToUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Bucketize rows into one or more time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The following example takes the average stock price for + * a one minute window every 10 seconds starting 5 seconds after the hour: + * + * {{{ + * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute", "10 seconds", "5 seconds"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:05-09:01:05 + * 09:00:15-09:01:15 + * 09:00:25-09:01:25 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time can be as TimestampType or LongType, however when using LongType, + * the time must be given in seconds. 
+ * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. + * A new window will be generated every `slideDuration`. Must be less than + * or equal to the `windowDuration`. Check + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration + * identifiers. + * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start + * window intervals. For example, in order to have hourly tumbling windows that + * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide + * `startTime` as `15 minutes`. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window( + timeColumn: Column, + windowDuration: String, + slideDuration: String, + startTime: String): Column = { + withExpr { + TimeWindow(timeColumn.expr, windowDuration, slideDuration, startTime) + }.as("window") + } + + + /** + * Bucketize rows into one or more time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. + * The following example takes the average stock price for a one minute window every 10 seconds: + * + * {{{ + * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute", "10 seconds"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:00-09:01:00 + * 09:00:10-09:01:10 + * 09:00:20-09:01:20 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time can be as TimestampType or LongType, however when using LongType, + * the time must be given in seconds. + * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. + * A new window will be generated every `slideDuration`. Must be less than + * or equal to the `windowDuration`. Check + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { + window(timeColumn, windowDuration, slideDuration, "0 second") + } + + /** + * Generates tumbling time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. + * The following example takes the average stock price for a one minute tumbling window: + * + * {{{ + * val df = ... 
// schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:00-09:01:00 + * 09:01:00-09:02:00 + * 09:02:00-09:03:00 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time can be as TimestampType or LongType, however when using LongType, + * the time must be given in seconds. + * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window(timeColumn: Column, windowDuration: String): Column = { + window(timeColumn, windowDuration, windowDuration, "0 second") + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala new file mode 100644 index 0000000000..e8103a31d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StringType + +class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { + + import testImplicits._ + + override def beforeEach(): Unit = { + super.beforeEach() + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + } + + override def afterEach(): Unit = { + super.beforeEach() + TimeZone.setDefault(null) + } + + test("tumbling window groupBy statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1)) + ) + } + + test("tumbling window groupBy statement with startTime") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"id") + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1))) + } + + test("tumbling window with multi-column projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + + test("sliding window grouping") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "3 seconds", "0 second")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + // 2016-03-27 19:39:27 UTC -> 4 bins + // 2016-03-27 19:39:34 UTC -> 3 bins + // 2016-03-27 19:39:56 UTC -> 3 bins + Seq( + Row("2016-03-27 19:39:18", "2016-03-27 19:39:28", 1), + Row("2016-03-27 19:39:21", "2016-03-27 19:39:31", 1), + Row("2016-03-27 19:39:24", "2016-03-27 19:39:34", 1), + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 2), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:33", "2016-03-27 19:39:43", 1), + Row("2016-03-27 19:39:48", "2016-03-27 19:39:58", 1), + Row("2016-03-27 19:39:51", "2016-03-27 19:40:01", 1), + Row("2016-03-27 19:39:54", "2016-03-27 19:40:04", 1)) + ) + } + + test("sliding window projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value"), + // 2016-03-27 19:39:27 UTC -> 4 bins + // 2016-03-27 19:39:34 UTC -> 3 bins + // 2016-03-27 19:39:56 UTC -> 3 bins + Seq(Row(4), Row(4), Row(4), Row(4), Row(1), Row(1), 
Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("windowing combined with explode expression") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"window.start".asc).select("value"), + // first window exploded to two rows for "a", and "b", second window exploded to 3 rows + Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("null timestamps") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + checkDataset( + df.select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select("value") + .as[Int], + 1, 2) // null columns are dropped + } + + test("time window joins") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + val df2 = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "othervalue") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value").join( + df2.select(window($"time", "10 seconds"), $"othervalue"), Seq("window")) + .groupBy("window") + .agg((sum("value") + sum("othervalue")).as("total")) + .orderBy($"window.start".asc).select("total"), + Seq(Row(4), Row(8))) + } + + test("negative timestamps") { + val df4 = Seq( + ("1970-01-01 00:00:02", 1), + ("1970-01-01 00:00:12", 2)).toDF("time", "value") + checkAnswer( + df4.select(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("1969-12-31 23:59:55", "1970-01-01 00:00:05", 1), + Row("1970-01-01 00:00:05", "1970-01-01 00:00:15", 2)) + ) + } + + test("multiple time windows in a single operator throws nice exception") { + val df = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "value") + val e = intercept[AnalysisException] { + df.select(window($"time", "10 second"), window($"time", "15 second")).collect() + } + assert(e.getMessage.contains( + "Multiple time window expressions would result in a cartesian product")) + } + + test("aliased windows") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(window($"time", "10 seconds").as("time_window"), $"value") + .orderBy($"time_window.start".asc) + .select("value"), + Seq(Row(1), Row(2)) + ) + } + + test("millisecond precision sliding windows") { + val df = Seq( + ("2016-03-27 09:00:00.41", 3), + ("2016-03-27 09:00:00.62", 6), + ("2016-03-27 09:00:00.715", 8)).toDF("time", "value") + checkAnswer( + df.groupBy(window($"time", "200 milliseconds", "40 milliseconds", "0 milliseconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"counts"), + Seq( + Row("2016-03-27 09:00:00.24", "2016-03-27 09:00:00.44", 1), + Row("2016-03-27 09:00:00.28", "2016-03-27 09:00:00.48", 1), + Row("2016-03-27 09:00:00.32", "2016-03-27 09:00:00.52", 1), + Row("2016-03-27 09:00:00.36", "2016-03-27 09:00:00.56", 1), + Row("2016-03-27 09:00:00.4", "2016-03-27 09:00:00.6", 1), + Row("2016-03-27 09:00:00.44", "2016-03-27 09:00:00.64", 1), + Row("2016-03-27 09:00:00.48", "2016-03-27 09:00:00.68", 1), + Row("2016-03-27 
09:00:00.52", "2016-03-27 09:00:00.72", 2), + Row("2016-03-27 09:00:00.56", "2016-03-27 09:00:00.76", 2), + Row("2016-03-27 09:00:00.6", "2016-03-27 09:00:00.8", 2), + Row("2016-03-27 09:00:00.64", "2016-03-27 09:00:00.84", 1), + Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1)) + ) + } +} -- cgit v1.2.3 From 3e991dbc310a4a33eec7f3909adce50bf8268d04 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Apr 2016 14:02:32 -0700 Subject: [SPARK-13674] [SQL] Add wholestage codegen support to Sample JIRA: https://issues.apache.org/jira/browse/SPARK-13674 ## What changes were proposed in this pull request? Sample operator doesn't support wholestage codegen now. This pr is to add support to it. ## How was this patch tested? A test is added into `BenchmarkWholeStageCodegen`. Besides, all tests should be passed. Author: Liang-Chi Hsieh Author: Liang-Chi Hsieh Closes #11517 from viirya/add-wholestage-sample. --- .../apache/spark/util/random/RandomSampler.scala | 2 +- project/MimaExcludes.scala | 4 ++ .../spark/sql/execution/BufferedRowIterator.java | 4 +- .../spark/sql/execution/WholeStageCodegen.scala | 12 ++-- .../spark/sql/execution/basicOperators.scala | 72 +++++++++++++++++++--- .../sql/execution/BenchmarkWholeStageCodegen.scala | 25 ++++++++ 6 files changed, 104 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 2921b939bc..d397cca4b4 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -186,7 +186,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T * @tparam T item type */ @DeveloperApi -class PoissonSampler[T: ClassTag]( +class PoissonSampler[T]( fraction: Double, useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index ff11775412..2be490b942 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -597,6 +597,10 @@ object MimaExcludes { // for multilayer perceptron. // This class is marked as `private`. ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction") + ) ++ Seq( + // [SPARK-13674][SQL] Add wholestage codegen support to Sample + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this") ) case v if v.startsWith("1.6") => Seq( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index dbea8521be..c2633a9f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -36,6 +36,8 @@ public abstract class BufferedRowIterator { protected UnsafeRow unsafeRow = new UnsafeRow(0); private long startTimeNs = System.nanoTime(); + protected int partitionIndex = -1; + public boolean hasNext() throws IOException { if (currentRows.isEmpty()) { processNext(); @@ -58,7 +60,7 @@ public abstract class BufferedRowIterator { /** * Initializes from array of iterators of InternalRow. 
*/ - public abstract void init(Iterator iters[]); + public abstract void init(int index, Iterator iters[]); /** * Append a row to currentRows. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 6a779abd40..9bdf611f6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.broadcast +import org.apache.spark.{broadcast, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -323,7 +323,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup this.references = references; } - public void init(scala.collection.Iterator inputs[]) { + public void init(int index, scala.collection.Iterator inputs[]) { + partitionIndex = index; ${ctx.initMutableStates()} } @@ -351,10 +352,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup val rdds = child.asInstanceOf[CodegenSupport].upstreams() assert(rdds.size <= 2, "Up to two upstream RDDs can be supported") if (rdds.length == 1) { - rdds.head.mapPartitions { iter => + rdds.head.mapPartitionsWithIndex { (index, iter) => val clazz = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] - buffer.init(Array(iter)) + buffer.init(index, Array(iter)) new Iterator[InternalRow] { override def hasNext: Boolean = { val v = buffer.hasNext @@ -367,9 +368,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup } else { // Right now, we support up to two upstreams. 
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) => + val partitionIndex = TaskContext.getPartitionId() val clazz = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] - buffer.init(Array(leftIter, rightIter)) + buffer.init(partitionIndex, Array(leftIter, rightIter)) new Iterator[InternalRow] { override def hasNext: Boolean = { val v = buffer.hasNext diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fca662760d..a6a14df6a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.LongType -import org.apache.spark.util.random.PoissonSampler +import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler} case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with CodegenSupport { @@ -223,9 +223,12 @@ case class Sample( upperBound: Double, withReplacement: Boolean, seed: Long, - child: SparkPlan) extends UnaryNode { + child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output + private[sql] override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) + protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, @@ -239,6 +242,63 @@ case class Sample( child.execute().randomSampleWithRange(lowerBound, upperBound, seed) } } + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val numOutput = metricTerm(ctx, "numOutputRows") + val sampler = ctx.freshName("sampler") + + if (withReplacement) { + val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName + val initSampler = ctx.freshName("initSampler") + ctx.addMutableState(s"$samplerClass", sampler, + s"$initSampler();") + + ctx.addNewFunction(initSampler, + s""" + | private void $initSampler() { + | $sampler = new $samplerClass($upperBound - $lowerBound, false); + | java.util.Random random = new java.util.Random(${seed}L); + | long randomSeed = random.nextLong(); + | int loopCount = 0; + | while (loopCount < partitionIndex) { + | randomSeed = random.nextLong(); + | loopCount += 1; + | } + | $sampler.setSeed(randomSeed); + | } + """.stripMargin.trim) + + val samplingCount = ctx.freshName("samplingCount") + s""" + | int $samplingCount = $sampler.sample(); + | while ($samplingCount-- > 0) { + | $numOutput.add(1); + | 
${consume(ctx, input)} + | } + """.stripMargin.trim + } else { + val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName + ctx.addMutableState(s"$samplerClass", sampler, + s""" + | $sampler = new $samplerClass($lowerBound, $upperBound, false); + | $sampler.setSeed(${seed}L + partitionIndex); + """.stripMargin.trim) + + s""" + | if ($sampler.sample() == 0) continue; + | $numOutput.add(1); + | ${consume(ctx, input)} + """.stripMargin.trim + } + } } case class Range( @@ -320,11 +380,7 @@ case class Range( | // initialize Range | if (!$initTerm) { | $initTerm = true; - | if ($input.hasNext()) { - | initRange(((InternalRow) $input.next()).getInt(0)); - | } else { - | return; - | } + | initRange(partitionIndex); | } | | while (!$overflow && $checkEnd) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 003d3e062e..55906793c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -85,6 +85,31 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } + ignore("range/sample/sum") { + val N = 500 << 20 + runBenchmark("range/sample/sum", N) { + sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/sample/sum codegen=false 53888 / 56592 9.7 102.8 1.0X + range/sample/sum codegen=true 41614 / 42607 12.6 79.4 1.3X + */ + + runBenchmark("range/sample/sum", N) { + sqlContext.range(N).sample(false, 0.01).groupBy().sum().collect() + } + /* + Westmere E56xx/L56xx/X56xx (Nehalem-C) + range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + range/sample/sum codegen=false 12982 / 13384 40.4 24.8 1.0X + range/sample/sum codegen=true 7074 / 7383 74.1 13.5 1.8X + */ + } + ignore("stat functions") { val N = 100L << 20 -- cgit v1.2.3 From bd7b91cefb0d192d808778e6182dcdd2c143e132 Mon Sep 17 00:00:00 2001 From: zhonghaihua <793507405@qq.com> Date: Fri, 1 Apr 2016 16:23:14 -0500 Subject: [SPARK-12864][YARN] initialize executorIdCounter after ApplicationMaster killed for max n… MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, when max number of executor failures reached the `maxNumExecutorFailures`, `ApplicationMaster` will be killed and re-register another one.This time, `YarnAllocator` will be created a new instance. But, the value of property `executorIdCounter` in `YarnAllocator` will reset to `0`. Then the Id of new executor will starting from `1`. This will confuse with the executor has already created before, which will cause FetchFailedException. This situation is just in yarn client mode, so this is an issue in yarn client mode. For more details, [link to jira issues SPARK-12864](https://issues.apache.org/jira/browse/SPARK-12864) This PR introduce a mechanism to initialize `executorIdCounter` after `ApplicationMaster` killed. Author: zhonghaihua <793507405@qq.com> Closes #10794 from zhonghaihua/initExecutorIdCounterAfterAMKilled. 
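A minimal, self-contained sketch of the idea above (DriverStub, AllocatorStub and requestExecutor are illustrative stand-ins rather than Spark classes; the real patch routes the lookup through the RetrieveLastAllocatedExecutorId RPC message shown in the diff below):

    // Sketch only: a restarted allocator resumes numbering from the largest executor id
    // the driver has already seen, so it never re-issues an id that is still in use.
    object ExecutorIdCounterSketch {

      // Stand-in for the driver endpoint: remembers the largest registered executor id.
      class DriverStub {
        private var maxSeenExecutorId = 0
        def registerExecutor(id: Int): Unit = synchronized {
          if (id > maxSeenExecutorId) maxSeenExecutorId = id
        }
        def lastAllocatedExecutorId: Int = synchronized { maxSeenExecutorId }
      }

      // Stand-in for YarnAllocator: initializes its counter from the driver instead of 0.
      class AllocatorStub(driver: DriverStub) {
        private var executorIdCounter = driver.lastAllocatedExecutorId
        def requestExecutor(): String = {
          executorIdCounter += 1
          driver.registerExecutor(executorIdCounter)
          executorIdCounter.toString
        }
      }

      def main(args: Array[String]): Unit = {
        val driver = new DriverStub
        val firstAm = new AllocatorStub(driver)
        (1 to 3).foreach(_ => firstAm.requestExecutor())   // issues executors "1", "2", "3"

        // The AM exceeds maxNumExecutorFailures and is replaced by a fresh allocator.
        val secondAm = new AllocatorStub(driver)
        println(secondAm.requestExecutor())                // prints "4", not a duplicate "1"
      }
    }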
--- .../cluster/CoarseGrainedClusterMessage.scala | 2 ++ .../cluster/CoarseGrainedSchedulerBackend.scala | 6 ++++++ .../org/apache/spark/deploy/yarn/YarnAllocator.scala | 20 ++++++++++++++++++-- .../scheduler/cluster/YarnSchedulerBackend.scala | 3 +++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 8d5c11dc36..46a829114e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -30,6 +30,8 @@ private[spark] object CoarseGrainedClusterMessages { case object RetrieveSparkProps extends CoarseGrainedClusterMessage + case object RetrieveLastAllocatedExecutorId extends CoarseGrainedClusterMessage + // Driver to executors case class LaunchTask(data: SerializableBuffer) extends CoarseGrainedClusterMessage diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index eb4f5331d6..70470cc6d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -79,6 +79,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors that have been lost, but for which we don't yet know the real exit reason. protected val executorsPendingLossReason = new HashSet[String] + // The num of current max ExecutorId used to re-register appMaster + protected var currentExecutorIdCounter = 0 + class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { @@ -156,6 +159,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { executorDataMap.put(executorId, data) + if (currentExecutorIdCounter < executorId.toInt) { + currentExecutorIdCounter = executorId.toInt + } if (numPendingExecutors > 0) { numPendingExecutors -= 1 logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index 7d71a642f6..b0bfe855e9 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -40,6 +40,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef} import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RemoveExecutor +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.RetrieveLastAllocatedExecutorId import org.apache.spark.util.ThreadUtils /** @@ -83,8 +84,23 @@ private[yarn] class YarnAllocator( new ConcurrentHashMap[ContainerId, java.lang.Boolean]) @volatile private var numExecutorsRunning = 0 - // Used to generate a unique ID per executor - private var executorIdCounter = 0 + + /** + * Used to generate a unique ID per executor + * + * Init `executorIdCounter`. 
when AM restart, `executorIdCounter` will reset to 0. Then + * the id of new executor will start from 1, this will conflict with the executor has + * already created before. So, we should initialize the `executorIdCounter` by getting + * the max executorId from driver. + * + * And this situation of executorId conflict is just in yarn client mode, so this is an issue + * in yarn client mode. For more details, can check in jira. + * + * @see SPARK-12864 + */ + private var executorIdCounter: Int = + driverRef.askWithRetry[Int](RetrieveLastAllocatedExecutorId) + @volatile private var numExecutorsFailed = 0 @volatile private var targetNumExecutors = diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index a8781636f2..5aeaf44732 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -292,6 +292,9 @@ private[spark] abstract class YarnSchedulerBackend( logWarning("Attempted to kill executors before the AM has registered!") context.reply(false) } + + case RetrieveLastAllocatedExecutorId => + context.reply(currentExecutorIdCounter) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { -- cgit v1.2.3 From e41acb757327e3226ffe312766ec759c16616588 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 1 Apr 2016 14:34:59 -0700 Subject: [SPARK-13992] Add support for off-heap caching This patch adds support for caching blocks in the executor processes using direct / off-heap memory. ## User-facing changes **Updated semantics of `OFF_HEAP` storage level**: In Spark 1.x, the `OFF_HEAP` storage level indicated that an RDD should be cached in Tachyon. Spark 2.x removed the external block store API that Tachyon caching was based on (see #10752 / SPARK-12667), so `OFF_HEAP` became an alias for `MEMORY_ONLY_SER`. As of this patch, `OFF_HEAP` means "serialized and cached in off-heap memory or on disk". Via the `StorageLevel` constructor, `useOffHeap` can be set if `serialized == true` and can be used to construct custom storage levels which support replication. **Storage UI reporting**: the storage UI will now report whether in-memory blocks are stored on- or off-heap. **Only supported by UnifiedMemoryManager**: for simplicity, this feature is only supported when the default UnifiedMemoryManager is used; applications which use the legacy memory manager (`spark.memory.useLegacyMode=true`) are not currently able to allocate off-heap storage memory, so using off-heap caching will fail with an error when legacy memory management is enabled. Given that we plan to eventually remove the legacy memory manager, this is not a significant restriction. **Memory management policies:** the policies for dividing available memory between execution and storage are the same for both on- and off-heap memory. For off-heap memory, the total amount of memory available for use by Spark is controlled by `spark.memory.offHeap.size`, which is an absolute size. Off-heap storage memory obeys `spark.memory.storageFraction` in order to control the amount of unevictable storage memory. For example, if `spark.memory.offHeap.size` is 1 gigabyte and Spark uses the default `storageFraction` of 0.5, then up to 500 megabytes of off-heap cached blocks will be protected from eviction due to execution memory pressure. 
If necessary, we can split `spark.memory.storageFraction` into separate on- and off-heap configurations, but this doesn't seem necessary now and can be done later without any breaking changes. **Use of off-heap memory does not imply use of off-heap execution (or vice-versa)**: for now, the settings controlling the use of off-heap execution memory (`spark.memory.offHeap.enabled`) and off-heap caching are completely independent, so Spark SQL can be configured to use off-heap memory for execution while continuing to cache blocks on-heap. If desired, we can change this in a followup patch so that `spark.memory.offHeap.enabled` affect the default storage level for cached SQL tables. ## Internal changes - Rename `ByteArrayChunkOutputStream` to `ChunkedByteBufferOutputStream` - It now returns a `ChunkedByteBuffer` instead of an array of byte arrays. - Its constructor now accept an `allocator` function which is called to allocate `ByteBuffer`s. This allows us to control whether it allocates regular ByteBuffers or off-heap DirectByteBuffers. - Because block serialization is now performed during the unroll process, a `ChunkedByteBufferOutputStream` which is configured with a `DirectByteBuffer` allocator will use off-heap memory for both unroll and storage memory. - The `MemoryStore`'s MemoryEntries now tracks whether blocks are stored on- or off-heap. - `evictBlocksToFreeSpace()` now accepts a `MemoryMode` parameter so that we don't try to evict off-heap blocks in response to on-heap memory pressure (or vice-versa). - Make sure that off-heap buffers are properly de-allocated during MemoryStore eviction. - The JVM limits the total size of allocated direct byte buffers using the `-XX:MaxDirectMemorySize` flag and the default tends to be fairly low (< 512 megabytes in some JVMs). To work around this limitation, this patch adds a custom DirectByteBuffer allocator which ignores this memory limit. Author: Josh Rosen Closes #11805 from JoshRosen/off-heap-caching. 
--- .../java/org/apache/spark/unsafe/Platform.java | 32 +++++ .../apache/spark/broadcast/TorrentBroadcast.scala | 8 +- .../apache/spark/memory/StorageMemoryPool.scala | 22 ++- .../scala/org/apache/spark/scheduler/Task.scala | 5 +- .../spark/serializer/SerializerManager.scala | 16 +-- .../org/apache/spark/storage/BlockManager.scala | 70 ++++++---- .../spark/storage/BlockManagerMasterEndpoint.scala | 2 +- .../org/apache/spark/storage/StorageLevel.scala | 21 ++- .../apache/spark/storage/memory/MemoryStore.scala | 153 +++++++++++++++------ .../spark/util/io/ByteArrayChunkOutputStream.scala | 99 ------------- .../apache/spark/util/io/ChunkedByteBuffer.scala | 14 +- .../util/io/ChunkedByteBufferOutputStream.scala | 113 +++++++++++++++ .../scala/org/apache/spark/DistributedSuite.scala | 4 +- .../apache/spark/io/ChunkedByteBufferSuite.scala | 2 +- .../apache/spark/memory/MemoryManagerSuite.scala | 3 +- .../storage/BlockManagerReplicationSuite.scala | 24 +++- .../apache/spark/storage/BlockManagerSuite.scala | 32 ++++- .../apache/spark/storage/MemoryStoreSuite.scala | 22 +-- .../util/io/ByteArrayChunkOutputStreamSuite.scala | 109 --------------- .../io/ChunkedByteBufferOutputStreamSuite.scala | 114 +++++++++++++++ .../rdd/WriteAheadLogBackedBlockRDD.scala | 3 +- .../streaming/ReceivedBlockHandlerSuite.scala | 3 +- 22 files changed, 520 insertions(+), 351 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala create mode 100644 core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 672552cc65..bdf52f32c6 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -17,9 +17,12 @@ package org.apache.spark.unsafe; +import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import sun.misc.Cleaner; import sun.misc.Unsafe; public final class Platform { @@ -144,6 +147,35 @@ public final class Platform { return newMemory; } + /** + * Uses internal JDK APIs to allocate a DirectByteBuffer while ignoring the JVM's + * MaxDirectMemorySize limit (the default limit is too low and we do not want to require users + * to increase it). 
+ */ + @SuppressWarnings("unchecked") + public static ByteBuffer allocateDirectBuffer(int size) { + try { + Class cls = Class.forName("java.nio.DirectByteBuffer"); + Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE); + constructor.setAccessible(true); + Field cleanerField = cls.getDeclaredField("cleaner"); + cleanerField.setAccessible(true); + final long memory = allocateMemory(size); + ByteBuffer buffer = (ByteBuffer) constructor.newInstance(memory, size); + Cleaner cleaner = Cleaner.create(buffer, new Runnable() { + @Override + public void run() { + freeMemory(memory); + } + }); + cleanerField.set(buffer, cleaner); + return buffer; + } catch (Exception e) { + throwException(e); + } + throw new IllegalStateException("unreachable"); + } + public static void setMemory(long address, byte value, long size) { _UNSAFE.setMemory(address, size, value); } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index e5e6a9e4a8..632b0ae9c2 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -30,7 +30,7 @@ import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, BroadcastBlockId, StorageLevel} import org.apache.spark.util.{ByteBufferInputStream, Utils} -import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer} +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]]. @@ -228,12 +228,12 @@ private object TorrentBroadcast extends Logging { blockSize: Int, serializer: Serializer, compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = { - val bos = new ByteArrayChunkOutputStream(blockSize) - val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos) + val cbbos = new ChunkedByteBufferOutputStream(blockSize, ByteBuffer.allocate) + val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos) val ser = serializer.newInstance() val serOut = ser.serializeStream(out) serOut.writeObject[T](obj).close() - bos.toArrays.map(ByteBuffer.wrap) + cbbos.toChunkedByteBuffer.getChunks() } def unBlockifyObject[T: ClassTag]( diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index a67e8da26b..0b552cabfc 100644 --- a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -35,6 +35,11 @@ private[memory] class StorageMemoryPool( memoryMode: MemoryMode ) extends MemoryPool(lock) with Logging { + private[this] val poolName: String = memoryMode match { + case MemoryMode.ON_HEAP => "on-heap storage" + case MemoryMode.OFF_HEAP => "off-heap storage" + } + @GuardedBy("lock") private[this] var _memoryUsed: Long = 0L @@ -60,7 +65,7 @@ private[memory] class StorageMemoryPool( /** * Acquire N bytes of memory to cache the given block, evicting existing ones if necessary. - * + * * @return whether all N bytes were successfully granted. 
*/ def acquireMemory(blockId: BlockId, numBytes: Long): Boolean = lock.synchronized { @@ -83,9 +88,8 @@ private[memory] class StorageMemoryPool( assert(numBytesToAcquire >= 0) assert(numBytesToFree >= 0) assert(memoryUsed <= poolSize) - // Once we support off-heap caching, this will need to change: - if (numBytesToFree > 0 && memoryMode == MemoryMode.ON_HEAP) { - memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree) + if (numBytesToFree > 0) { + memoryStore.evictBlocksToFreeSpace(Some(blockId), numBytesToFree, memoryMode) } // NOTE: If the memory store evicts blocks, then those evictions will synchronously call // back into this StorageMemoryPool in order to free memory. Therefore, these variables @@ -122,14 +126,8 @@ private[memory] class StorageMemoryPool( val remainingSpaceToFree = spaceToFree - spaceFreedByReleasingUnusedMemory if (remainingSpaceToFree > 0) { // If reclaiming free memory did not adequately shrink the pool, begin evicting blocks: - val spaceFreedByEviction = { - // Once we support off-heap caching, this will need to change: - if (memoryMode == MemoryMode.ON_HEAP) { - memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree) - } else { - 0 - } - } + val spaceFreedByEviction = + memoryStore.evictBlocksToFreeSpace(None, remainingSpaceToFree, memoryMode) // When a block is released, BlockManager.dropFromMemory() calls releaseMemory(), so we do // not need to decrement _memoryUsed here. However, we do need to decrement the pool size. decrementPoolSize(spaceFreedByEviction) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index d2b8ca90a9..46c64f61de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.HashMap import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} @@ -90,7 +90,8 @@ private[spark] abstract class Task[T]( try { Utils.tryLogNonFatalError { // Release memory used by this thread for unrolling blocks - SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask() + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) + SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP) // Notify any tasks waiting for execution memory to be freed to wake up and try to // acquire memory again. This makes impossible the scenario where a task sleeps forever // because there are no other tasks left to notify it. 
Since this is safe to do but may diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 27e5fa4c2b..745ef12691 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.spark.SparkConf import org.apache.spark.io.CompressionCodec import org.apache.spark.storage._ -import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer} +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** * Component which configures serialization and compression for various Spark components, including @@ -128,17 +128,9 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar /** Serializes into a chunked byte buffer. */ def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = { - val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(1024 * 1024 * 4) - dataSerializeStream(blockId, byteArrayChunkOutputStream, values) - new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)) - } - - /** - * Deserializes a ByteBuffer into an iterator of values and disposes of it when the end of - * the iterator is reached. - */ - def dataDeserialize[T: ClassTag](blockId: BlockId, bytes: ChunkedByteBuffer): Iterator[T] = { - dataDeserializeStream[T](blockId, bytes.toInputStream(dispose = true)) + val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate) + dataSerializeStream(blockId, bbos, values) + bbos.toChunkedByteBuffer } /** diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3014cafc28..9608418b43 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io._ +import java.nio.ByteBuffer import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, ExecutionContext, Future} @@ -39,6 +40,7 @@ import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.{SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.storage.memory._ +import org.apache.spark.unsafe.Platform import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer @@ -372,8 +374,12 @@ private[spark] class BlockManager( val onDisk = level.useDisk && diskStore.contains(blockId) val deserialized = if (inMem) level.deserialized else false val replication = if (inMem || onDisk) level.replication else 1 - val storageLevel = - StorageLevel(onDisk, inMem, deserialized, replication) + val storageLevel = StorageLevel( + useDisk = onDisk, + useMemory = inMem, + useOffHeap = level.useOffHeap, + deserialized = deserialized, + replication = replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L BlockStatus(storageLevel, memSize, diskSize) @@ -407,8 +413,8 @@ private[spark] class BlockManager( val iter: Iterator[Any] = if (level.deserialized) { memoryStore.getValues(blockId).get } else { - serializerManager.dataDeserialize( - blockId, memoryStore.getBytes(blockId).get)(info.classTag) + 
serializerManager.dataDeserializeStream( + blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag) } val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId)) Some(new BlockResult(ci, DataReadMethod.Memory, info.size)) @@ -416,11 +422,15 @@ private[spark] class BlockManager( val iterToReturn: Iterator[Any] = { val diskBytes = diskStore.getBytes(blockId) if (level.deserialized) { - val diskValues = serializerManager.dataDeserialize(blockId, diskBytes)(info.classTag) + val diskValues = serializerManager.dataDeserializeStream( + blockId, + diskBytes.toInputStream(dispose = true))(info.classTag) maybeCacheDiskValuesInMemory(info, blockId, level, diskValues) } else { - val bytes = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) - serializerManager.dataDeserialize(blockId, bytes)(info.classTag) + val stream = maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes) + .map {_.toInputStream(dispose = false)} + .getOrElse { diskBytes.toInputStream(dispose = true) } + serializerManager.dataDeserializeStream(blockId, stream)(info.classTag) } } val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId)) @@ -481,7 +491,8 @@ private[spark] class BlockManager( if (level.useMemory && memoryStore.contains(blockId)) { memoryStore.getBytes(blockId).get } else if (level.useDisk && diskStore.contains(blockId)) { - maybeCacheDiskBytesInMemory(info, blockId, level, diskStore.getBytes(blockId)) + val diskBytes = diskStore.getBytes(blockId) + maybeCacheDiskBytesInMemory(info, blockId, level, diskBytes).getOrElse(diskBytes) } else { releaseLock(blockId) throw new SparkException(s"Block $blockId was not found even though it's read-locked") @@ -496,8 +507,9 @@ private[spark] class BlockManager( */ private def getRemoteValues(blockId: BlockId): Option[BlockResult] = { getRemoteBytes(blockId).map { data => - new BlockResult( - serializerManager.dataDeserialize(blockId, data), DataReadMethod.Network, data.size) + val values = + serializerManager.dataDeserializeStream(blockId, data.toInputStream(dispose = true)) + new BlockResult(values, DataReadMethod.Network, data.size) } } @@ -745,7 +757,8 @@ private[spark] class BlockManager( // Put it in memory first, even if it also has useDisk set to true; // We will drop it to disk later if the memory store can't hold it. val putSucceeded = if (level.deserialized) { - val values = serializerManager.dataDeserialize(blockId, bytes)(classTag) + val values = + serializerManager.dataDeserializeStream(blockId, bytes.toInputStream())(classTag) memoryStore.putIteratorAsValues(blockId, values, classTag) match { case Right(_) => true case Left(iter) => @@ -755,7 +768,7 @@ private[spark] class BlockManager( false } } else { - memoryStore.putBytes(blockId, size, () => bytes) + memoryStore.putBytes(blockId, size, level.memoryMode, () => bytes) } if (!putSucceeded && level.useDisk) { logWarning(s"Persisting block $blockId to disk instead.") @@ -893,7 +906,7 @@ private[spark] class BlockManager( } } } else { // !level.deserialized - memoryStore.putIteratorAsBytes(blockId, iterator(), classTag) match { + memoryStore.putIteratorAsBytes(blockId, iterator(), classTag, level.memoryMode) match { case Right(s) => size = s case Left(partiallySerializedValues) => @@ -951,14 +964,16 @@ private[spark] class BlockManager( * Attempts to cache spilled bytes read from disk into the MemoryStore in order to speed up * subsequent reads. This method requires the caller to hold a read lock on the block. * - * @return a copy of the bytes. 
The original bytes passed this method should no longer - * be used after this method returns. + * @return a copy of the bytes from the memory store if the put succeeded, otherwise None. + * If this returns bytes from the memory store then the original disk store bytes will + * automatically be disposed and the caller should not continue to use them. Otherwise, + * if this returns None then the original disk store bytes will be unaffected. */ private def maybeCacheDiskBytesInMemory( blockInfo: BlockInfo, blockId: BlockId, level: StorageLevel, - diskBytes: ChunkedByteBuffer): ChunkedByteBuffer = { + diskBytes: ChunkedByteBuffer): Option[ChunkedByteBuffer] = { require(!level.deserialized) if (level.useMemory) { // Synchronize on blockInfo to guard against a race condition where two readers both try to @@ -966,25 +981,29 @@ private[spark] class BlockManager( blockInfo.synchronized { if (memoryStore.contains(blockId)) { diskBytes.dispose() - memoryStore.getBytes(blockId).get + Some(memoryStore.getBytes(blockId).get) } else { - val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, () => { + val allocator = level.memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + val putSucceeded = memoryStore.putBytes(blockId, diskBytes.size, level.memoryMode, () => { // https://issues.apache.org/jira/browse/SPARK-6076 // If the file size is bigger than the free memory, OOM will happen. So if we // cannot put it into MemoryStore, copyForMemory should not be created. That's why // this action is put into a `() => ChunkedByteBuffer` and created lazily. - diskBytes.copy() + diskBytes.copy(allocator) }) if (putSucceeded) { diskBytes.dispose() - memoryStore.getBytes(blockId).get + Some(memoryStore.getBytes(blockId).get) } else { - diskBytes + None } } } } else { - diskBytes + None } } @@ -1055,7 +1074,12 @@ private[spark] class BlockManager( val peersForReplication = new ArrayBuffer[BlockManagerId] val peersReplicatedTo = new ArrayBuffer[BlockManagerId] val peersFailedToReplicateTo = new ArrayBuffer[BlockManagerId] - val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) + val tLevel = StorageLevel( + useDisk = level.useDisk, + useMemory = level.useMemory, + useOffHeap = level.useOffHeap, + deserialized = level.deserialized, + replication = 1) val startTime = System.currentTimeMillis val random = new Random(blockId.hashCode) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index d2a5c69e15..8fa1215011 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -453,7 +453,7 @@ private[spark] class BlockManagerInfo( } if (storageLevel.isValid) { - /* isValid means it is either stored in-memory, on-disk or on-externalBlockStore. + /* isValid means it is either stored in-memory or on-disk. * The memSize here indicates the data size in or dropped from memory, * externalBlockStoreSize here indicates the data size in or dropped from externalBlockStore, * and the diskSize here indicates the data size in or dropped to disk. 
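// Illustrative, standalone sketch (plain JDK API only; Spark's Platform.allocateDirectBuffer,
// MemoryMode and maybeCacheDiskBytesInMemory are not used here, and the names below are
// invented for the example). The off-heap caching path in the hunks above comes down to
// choosing a direct-buffer allocator instead of a heap allocator based on the storage
// level's memory mode; direct buffers live outside the garbage-collected heap, which is
// why the patch is careful to dispose() of them when entries are dropped.
import java.nio.ByteBuffer

object AllocatorSketch {
  sealed trait Mode
  case object OnHeap extends Mode
  case object OffHeap extends Mode

  // Heap buffers are backed by a byte[]; direct buffers are allocated in native memory.
  def allocatorFor(mode: Mode): Int => ByteBuffer = mode match {
    case OnHeap  => ByteBuffer.allocate _
    case OffHeap => ByteBuffer.allocateDirect _
  }

  def main(args: Array[String]): Unit = {
    assert(!allocatorFor(OnHeap)(16).isDirect)
    assert(allocatorFor(OffHeap)(16).isDirect)
  }
}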
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 7d23295e25..216ec07934 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -60,10 +60,7 @@ class StorageLevel private( assert(replication < 40, "Replication restricted to be less than 40 for calculating hash codes") if (useOffHeap) { - require(!useDisk, "Off-heap storage level does not support using disk") - require(!useMemory, "Off-heap storage level does not support using heap memory") require(!deserialized, "Off-heap storage level does not support deserialized storage") - require(replication == 1, "Off-heap storage level does not support multiple replication") } private[spark] def memoryMode: MemoryMode = { @@ -86,7 +83,7 @@ class StorageLevel private( false } - def isValid: Boolean = (useMemory || useDisk || useOffHeap) && (replication > 0) + def isValid: Boolean = (useMemory || useDisk) && (replication > 0) def toInt: Int = { var ret = 0 @@ -123,7 +120,8 @@ class StorageLevel private( private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) override def toString: String = { - s"StorageLevel($useDisk, $useMemory, $useOffHeap, $deserialized, $replication)" + s"StorageLevel(disk=$useDisk, memory=$useMemory, offheap=$useOffHeap, " + + s"deserialized=$deserialized, replication=$replication)" } override def hashCode(): Int = toInt * 41 + replication @@ -131,8 +129,9 @@ class StorageLevel private( def description: String = { var result = "" result += (if (useDisk) "Disk " else "") - result += (if (useMemory) "Memory " else "") - result += (if (useOffHeap) "ExternalBlockStore " else "") + if (useMemory) { + result += (if (useOffHeap) "Memory (off heap) " else "Memory ") + } result += (if (deserialized) "Deserialized " else "Serialized ") result += s"${replication}x Replicated" result @@ -156,9 +155,7 @@ object StorageLevel { val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2) val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false) val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2) - - // Redirect to MEMORY_ONLY_SER for now. - val OFF_HEAP = MEMORY_ONLY_SER + val OFF_HEAP = new StorageLevel(true, true, true, false, 1) /** * :: DeveloperApi :: @@ -183,7 +180,7 @@ object StorageLevel { /** * :: DeveloperApi :: - * Create a new StorageLevel object without setting useOffHeap. + * Create a new StorageLevel object. */ @DeveloperApi def apply( @@ -198,7 +195,7 @@ object StorageLevel { /** * :: DeveloperApi :: - * Create a new StorageLevel object. + * Create a new StorageLevel object without setting useOffHeap. 
*/ @DeveloperApi def apply( diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 3ca41f32c1..df38d11e43 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -32,20 +32,25 @@ import org.apache.spark.internal.Logging import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel} +import org.apache.spark.unsafe.Platform import org.apache.spark.util.{CompletionIterator, SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector -import org.apache.spark.util.io.{ByteArrayChunkOutputStream, ChunkedByteBuffer} +import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} private sealed trait MemoryEntry[T] { def size: Long + def memoryMode: MemoryMode def classTag: ClassTag[T] } private case class DeserializedMemoryEntry[T]( value: Array[T], size: Long, - classTag: ClassTag[T]) extends MemoryEntry[T] + classTag: ClassTag[T]) extends MemoryEntry[T] { + val memoryMode: MemoryMode = MemoryMode.ON_HEAP +} private case class SerializedMemoryEntry[T]( buffer: ChunkedByteBuffer, + memoryMode: MemoryMode, classTag: ClassTag[T]) extends MemoryEntry[T] { def size: Long = buffer.size } @@ -86,7 +91,10 @@ private[spark] class MemoryStore( // A mapping from taskAttemptId to amount of memory used for unrolling a block (in bytes) // All accesses of this map are assumed to have manually synchronized on `memoryManager` - private val unrollMemoryMap = mutable.HashMap[Long, Long]() + private val onHeapUnrollMemoryMap = mutable.HashMap[Long, Long]() + // Note: off-heap unroll memory is only used in putIteratorAsBytes() because off-heap caching + // always stores serialized values. 
+ private val offHeapUnrollMemoryMap = mutable.HashMap[Long, Long]() // Initial memory to request before unrolling any block private val unrollMemoryThreshold: Long = @@ -131,13 +139,14 @@ private[spark] class MemoryStore( def putBytes[T: ClassTag]( blockId: BlockId, size: Long, + memoryMode: MemoryMode, _bytes: () => ChunkedByteBuffer): Boolean = { require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") - if (memoryManager.acquireStorageMemory(blockId, size, MemoryMode.ON_HEAP)) { + if (memoryManager.acquireStorageMemory(blockId, size, memoryMode)) { // We acquired enough memory for the block, so go ahead and put it val bytes = _bytes() assert(bytes.size == size) - val entry = new SerializedMemoryEntry[T](bytes, implicitly[ClassTag[T]]) + val entry = new SerializedMemoryEntry[T](bytes, memoryMode, implicitly[ClassTag[T]]) entries.synchronized { entries.put(blockId, entry) } @@ -190,7 +199,8 @@ private[spark] class MemoryStore( var vector = new SizeTrackingVector[T]()(classTag) // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold) + keepUnrolling = + reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, MemoryMode.ON_HEAP) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -207,7 +217,8 @@ private[spark] class MemoryStore( val currentSize = vector.estimateSize() if (currentSize >= memoryThreshold) { val amountToRequest = (currentSize * memoryGrowthFactor - memoryThreshold).toLong - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest) + keepUnrolling = + reserveUnrollMemoryForThisTask(blockId, amountToRequest, MemoryMode.ON_HEAP) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest } @@ -228,7 +239,7 @@ private[spark] class MemoryStore( def transferUnrollToStorage(amount: Long): Unit = { // Synchronize so that transfer is atomic memoryManager.synchronized { - releaseUnrollMemoryForThisTask(amount) + releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, amount) val success = memoryManager.acquireStorageMemory(blockId, amount, MemoryMode.ON_HEAP) assert(success, "transferring unroll memory to storage memory failed") } @@ -247,7 +258,7 @@ private[spark] class MemoryStore( // If this task attempt already owns more unroll memory than is necessary to store the // block, then release the extra memory that will not be used. val excessUnrollMemory = unrollMemoryUsedByThisBlock - size - releaseUnrollMemoryForThisTask(excessUnrollMemory) + releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, excessUnrollMemory) transferUnrollToStorage(size) true } @@ -295,10 +306,16 @@ private[spark] class MemoryStore( private[storage] def putIteratorAsBytes[T]( blockId: BlockId, values: Iterator[T], - classTag: ClassTag[T]): Either[PartiallySerializedBlock[T], Long] = { + classTag: ClassTag[T], + memoryMode: MemoryMode): Either[PartiallySerializedBlock[T], Long] = { require(!contains(blockId), s"Block $blockId is already present in the MemoryStore") + val allocator = memoryMode match { + case MemoryMode.ON_HEAP => ByteBuffer.allocate _ + case MemoryMode.OFF_HEAP => Platform.allocateDirectBuffer _ + } + // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true // Initial per-task memory to request for unrolling blocks (bytes). 
@@ -307,15 +324,15 @@ private[spark] class MemoryStore( var unrollMemoryUsedByThisBlock = 0L // Underlying buffer for unrolling the block val redirectableStream = new RedirectableOutputStream - val byteArrayChunkOutputStream = new ByteArrayChunkOutputStream(initialMemoryThreshold.toInt) - redirectableStream.setOutputStream(byteArrayChunkOutputStream) + val bbos = new ChunkedByteBufferOutputStream(initialMemoryThreshold.toInt, allocator) + redirectableStream.setOutputStream(bbos) val serializationStream: SerializationStream = { val ser = serializerManager.getSerializer(classTag).newInstance() ser.serializeStream(serializerManager.wrapForCompression(blockId, redirectableStream)) } // Request enough memory to begin unrolling - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold) + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, initialMemoryThreshold, memoryMode) if (!keepUnrolling) { logWarning(s"Failed to reserve initial memory threshold of " + @@ -325,9 +342,9 @@ private[spark] class MemoryStore( } def reserveAdditionalMemoryIfNecessary(): Unit = { - if (byteArrayChunkOutputStream.size > unrollMemoryUsedByThisBlock) { - val amountToRequest = byteArrayChunkOutputStream.size - unrollMemoryUsedByThisBlock - keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest) + if (bbos.size > unrollMemoryUsedByThisBlock) { + val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) if (keepUnrolling) { unrollMemoryUsedByThisBlock += amountToRequest } @@ -349,12 +366,11 @@ private[spark] class MemoryStore( } if (keepUnrolling) { - val entry = SerializedMemoryEntry[T]( - new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)), classTag) + val entry = SerializedMemoryEntry[T](bbos.toChunkedByteBuffer, memoryMode, classTag) // Synchronize so that transfer is atomic memoryManager.synchronized { - releaseUnrollMemoryForThisTask(unrollMemoryUsedByThisBlock) - val success = memoryManager.acquireStorageMemory(blockId, entry.size, MemoryMode.ON_HEAP) + releaseUnrollMemoryForThisTask(memoryMode, unrollMemoryUsedByThisBlock) + val success = memoryManager.acquireStorageMemory(blockId, entry.size, memoryMode) assert(success, "transferring unroll memory to storage memory failed") } entries.synchronized { @@ -365,7 +381,7 @@ private[spark] class MemoryStore( Right(entry.size) } else { // We ran out of space while unrolling the values for this block - logUnrollFailureMessage(blockId, byteArrayChunkOutputStream.size) + logUnrollFailureMessage(blockId, bbos.size) Left( new PartiallySerializedBlock( this, @@ -374,7 +390,8 @@ private[spark] class MemoryStore( serializationStream, redirectableStream, unrollMemoryUsedByThisBlock, - new ChunkedByteBuffer(byteArrayChunkOutputStream.toArrays.map(ByteBuffer.wrap)), + memoryMode, + bbos.toChunkedByteBuffer, values, classTag)) } @@ -386,7 +403,7 @@ private[spark] class MemoryStore( case null => None case e: DeserializedMemoryEntry[_] => throw new IllegalArgumentException("should only call getBytes on serialized blocks") - case SerializedMemoryEntry(bytes, _) => Some(bytes) + case SerializedMemoryEntry(bytes, _, _) => Some(bytes) } } @@ -407,8 +424,12 @@ private[spark] class MemoryStore( entries.remove(blockId) } if (entry != null) { - memoryManager.releaseStorageMemory(entry.size, MemoryMode.ON_HEAP) - logInfo(s"Block $blockId of size ${entry.size} dropped " + + entry match { + case SerializedMemoryEntry(buffer, _, _) 
=> buffer.dispose() + case _ => + } + memoryManager.releaseStorageMemory(entry.size, entry.memoryMode) + logDebug(s"Block $blockId of size ${entry.size} dropped " + s"from memory (free ${maxMemory - blocksMemoryUsed})") true } else { @@ -420,7 +441,8 @@ private[spark] class MemoryStore( entries.synchronized { entries.clear() } - unrollMemoryMap.clear() + onHeapUnrollMemoryMap.clear() + offHeapUnrollMemoryMap.clear() memoryManager.releaseAllStorageMemory() logInfo("MemoryStore cleared") } @@ -440,16 +462,20 @@ private[spark] class MemoryStore( * * @param blockId the ID of the block we are freeing space for, if any * @param space the size of this block + * @param memoryMode the type of memory to free (on- or off-heap) * @return the amount of memory (in bytes) freed by eviction */ - private[spark] def evictBlocksToFreeSpace(blockId: Option[BlockId], space: Long): Long = { + private[spark] def evictBlocksToFreeSpace( + blockId: Option[BlockId], + space: Long, + memoryMode: MemoryMode): Long = { assert(space > 0) memoryManager.synchronized { var freedMemory = 0L val rddToAdd = blockId.flatMap(getRddId) val selectedBlocks = new ArrayBuffer[BlockId] - def blockIsEvictable(blockId: BlockId): Boolean = { - rddToAdd.isEmpty || rddToAdd != getRddId(blockId) + def blockIsEvictable(blockId: BlockId, entry: MemoryEntry[_]): Boolean = { + entry.memoryMode == memoryMode && (rddToAdd.isEmpty || rddToAdd != getRddId(blockId)) } // This is synchronized to ensure that the set of entries is not changed // (because of getValue or getBytes) while traversing the iterator, as that @@ -459,7 +485,8 @@ private[spark] class MemoryStore( while (freedMemory < space && iterator.hasNext) { val pair = iterator.next() val blockId = pair.getKey - if (blockIsEvictable(blockId)) { + val entry = pair.getValue + if (blockIsEvictable(blockId, entry)) { // We don't want to evict blocks which are currently being read, so we need to obtain // an exclusive write lock on blocks which are candidates for eviction. We perform a // non-blocking "tryLock" here in order to ignore blocks which are locked for reading: @@ -474,7 +501,7 @@ private[spark] class MemoryStore( def dropBlock[T](blockId: BlockId, entry: MemoryEntry[T]): Unit = { val data = entry match { case DeserializedMemoryEntry(values, _, _) => Left(values) - case SerializedMemoryEntry(buffer, _) => Right(buffer) + case SerializedMemoryEntry(buffer, _, _) => Right(buffer) } val newEffectiveStorageLevel = blockEvictionHandler.dropFromMemory(blockId, () => data)(entry.classTag) @@ -530,11 +557,18 @@ private[spark] class MemoryStore( * * @return whether the request is granted. */ - def reserveUnrollMemoryForThisTask(blockId: BlockId, memory: Long): Boolean = { + def reserveUnrollMemoryForThisTask( + blockId: BlockId, + memory: Long, + memoryMode: MemoryMode): Boolean = { memoryManager.synchronized { - val success = memoryManager.acquireUnrollMemory(blockId, memory, MemoryMode.ON_HEAP) + val success = memoryManager.acquireUnrollMemory(blockId, memory, memoryMode) if (success) { val taskAttemptId = currentTaskAttemptId() + val unrollMemoryMap = memoryMode match { + case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap + case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap + } unrollMemoryMap(taskAttemptId) = unrollMemoryMap.getOrElse(taskAttemptId, 0L) + memory } success @@ -545,9 +579,13 @@ private[spark] class MemoryStore( * Release memory used by this task for unrolling blocks. * If the amount is not specified, remove the current task's allocation altogether. 
*/ - def releaseUnrollMemoryForThisTask(memory: Long = Long.MaxValue): Unit = { + def releaseUnrollMemoryForThisTask(memoryMode: MemoryMode, memory: Long = Long.MaxValue): Unit = { val taskAttemptId = currentTaskAttemptId() memoryManager.synchronized { + val unrollMemoryMap = memoryMode match { + case MemoryMode.ON_HEAP => onHeapUnrollMemoryMap + case MemoryMode.OFF_HEAP => offHeapUnrollMemoryMap + } if (unrollMemoryMap.contains(taskAttemptId)) { val memoryToRelease = math.min(memory, unrollMemoryMap(taskAttemptId)) if (memoryToRelease > 0) { @@ -555,7 +593,7 @@ private[spark] class MemoryStore( if (unrollMemoryMap(taskAttemptId) == 0) { unrollMemoryMap.remove(taskAttemptId) } - memoryManager.releaseUnrollMemory(memoryToRelease, MemoryMode.ON_HEAP) + memoryManager.releaseUnrollMemory(memoryToRelease, memoryMode) } } } @@ -565,20 +603,23 @@ private[spark] class MemoryStore( * Return the amount of memory currently occupied for unrolling blocks across all tasks. */ def currentUnrollMemory: Long = memoryManager.synchronized { - unrollMemoryMap.values.sum + onHeapUnrollMemoryMap.values.sum + offHeapUnrollMemoryMap.values.sum } /** * Return the amount of memory currently occupied for unrolling blocks by this task. */ def currentUnrollMemoryForThisTask: Long = memoryManager.synchronized { - unrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) + onHeapUnrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) + + offHeapUnrollMemoryMap.getOrElse(currentTaskAttemptId(), 0L) } /** * Return the number of tasks currently unrolling blocks. */ - private def numTasksUnrolling: Int = memoryManager.synchronized { unrollMemoryMap.keys.size } + private def numTasksUnrolling: Int = memoryManager.synchronized { + (onHeapUnrollMemoryMap.keys ++ offHeapUnrollMemoryMap.keys).toSet.size + } /** * Log information about current memory usage. @@ -627,7 +668,7 @@ private[storage] class PartiallyUnrolledIterator[T]( private[this] var iter: Iterator[T] = { val completionIterator = CompletionIterator[T, Iterator[T]](unrolled, { unrolledIteratorIsConsumed = true - memoryStore.releaseUnrollMemoryForThisTask(unrollMemory) + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) }) completionIterator ++ rest } @@ -640,7 +681,7 @@ private[storage] class PartiallyUnrolledIterator[T]( */ def close(): Unit = { if (!unrolledIteratorIsConsumed) { - memoryStore.releaseUnrollMemoryForThisTask(unrollMemory) + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory) unrolledIteratorIsConsumed = true } iter = null @@ -669,6 +710,7 @@ private class RedirectableOutputStream extends OutputStream { * @param serializationStream a serialization stream which writes to [[redirectableOutputStream]]. * @param redirectableOutputStream an OutputStream which can be redirected to a different sink. * @param unrollMemory the amount of unroll memory used by the values in `unrolled`. + * @param memoryMode whether the unroll memory is on- or off-heap * @param unrolled a byte buffer containing the partially-serialized values. * @param rest the rest of the original iterator passed to * [[MemoryStore.putIteratorAsValues()]]. 
@@ -681,18 +723,36 @@ private[storage] class PartiallySerializedBlock[T]( serializationStream: SerializationStream, redirectableOutputStream: RedirectableOutputStream, unrollMemory: Long, + memoryMode: MemoryMode, unrolled: ChunkedByteBuffer, rest: Iterator[T], classTag: ClassTag[T]) { + // If the task does not fully consume `valuesIterator` or otherwise fails to consume or dispose of + // this PartiallySerializedBlock then we risk leaking of direct buffers, so we use a task + // completion listener here in order to ensure that `unrolled.dispose()` is called at least once. + // The dispose() method is idempotent, so it's safe to call it unconditionally. + Option(TaskContext.get()).foreach { taskContext => + taskContext.addTaskCompletionListener { _ => + // When a task completes, its unroll memory will automatically be freed. Thus we do not call + // releaseUnrollMemoryForThisTask() here because we want to avoid double-freeing. + unrolled.dispose() + } + } + /** * Called to dispose of this block and free its memory. */ def discard(): Unit = { try { + // We want to close the output stream in order to free any resources associated with the + // serializer itself (such as Kryo's internal buffers). close() might cause data to be + // written, so redirect the output stream to discard that data. + redirectableOutputStream.setOutputStream(ByteStreams.nullOutputStream()) serializationStream.close() } finally { - memoryStore.releaseUnrollMemoryForThisTask(unrollMemory) + unrolled.dispose() + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) } } @@ -701,12 +761,14 @@ private[storage] class PartiallySerializedBlock[T]( * and then serializing the values from the original input iterator. */ def finishWritingToStream(os: OutputStream): Unit = { - ByteStreams.copy(unrolled.toInputStream(), os) + // `unrolled`'s underlying buffers will be freed once this input stream is fully read: + ByteStreams.copy(unrolled.toInputStream(dispose = true), os) + memoryStore.releaseUnrollMemoryForThisTask(memoryMode, unrollMemory) redirectableOutputStream.setOutputStream(os) while (rest.hasNext) { serializationStream.writeObject(rest.next())(classTag) } - discard() + serializationStream.close() } /** @@ -717,10 +779,13 @@ private[storage] class PartiallySerializedBlock[T]( * `close()` on it to free its resources. */ def valuesIterator: PartiallyUnrolledIterator[T] = { + // `unrolled`'s underlying buffers will be freed once this input stream is fully read: + val unrolledIter = serializerManager.dataDeserializeStream( + blockId, unrolled.toInputStream(dispose = true))(classTag) new PartiallyUnrolledIterator( memoryStore, unrollMemory, - unrolled = serializerManager.dataDeserialize(blockId, unrolled)(classTag), + unrolled = CompletionIterator[T, Iterator[T]](unrolledIter, discard()), rest = rest) } } diff --git a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala deleted file mode 100644 index 16fe3be303..0000000000 --- a/core/src/main/scala/org/apache/spark/util/io/ByteArrayChunkOutputStream.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.io - -import java.io.OutputStream - -import scala.collection.mutable.ArrayBuffer - - -/** - * An OutputStream that writes to fixed-size chunks of byte arrays. - * - * @param chunkSize size of each chunk, in bytes. - */ -private[spark] -class ByteArrayChunkOutputStream(chunkSize: Int) extends OutputStream { - - private[this] val chunks = new ArrayBuffer[Array[Byte]] - - /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ - private[this] var lastChunkIndex = -1 - - /** - * Next position to write in the last chunk. - * - * If this equals chunkSize, it means for next write we need to allocate a new chunk. - * This can also never be 0. - */ - private[this] var position = chunkSize - private[this] var _size = 0 - - def size: Long = _size - - override def write(b: Int): Unit = { - allocateNewChunkIfNeeded() - chunks(lastChunkIndex)(position) = b.toByte - position += 1 - _size += 1 - } - - override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { - var written = 0 - while (written < len) { - allocateNewChunkIfNeeded() - val thisBatch = math.min(chunkSize - position, len - written) - System.arraycopy(bytes, written + off, chunks(lastChunkIndex), position, thisBatch) - written += thisBatch - position += thisBatch - } - _size += len - } - - @inline - private def allocateNewChunkIfNeeded(): Unit = { - if (position == chunkSize) { - chunks += new Array[Byte](chunkSize) - lastChunkIndex += 1 - position = 0 - } - } - - def toArrays: Array[Array[Byte]] = { - if (lastChunkIndex == -1) { - new Array[Array[Byte]](0) - } else { - // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. - // An alternative would have been returning an array of ByteBuffers, with the last buffer - // bounded to only the last chunk's position. However, given our use case in Spark (to put - // the chunks in block manager), only limiting the view bound of the buffer would still - // require the block manager to store the whole chunk. 
- val ret = new Array[Array[Byte]](chunks.size) - for (i <- 0 until chunks.size - 1) { - ret(i) = chunks(i) - } - if (position == chunkSize) { - ret(lastChunkIndex) = chunks(lastChunkIndex) - } else { - ret(lastChunkIndex) = new Array[Byte](position) - System.arraycopy(chunks(lastChunkIndex), 0, ret(lastChunkIndex), 0, position) - } - ret - } - } -} diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala index c643c4b63c..fb4706e78d 100644 --- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala @@ -41,6 +41,8 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { require(chunks.forall(_.limit() > 0), "chunks must be non-empty") require(chunks.forall(_.position() == 0), "chunks' positions must be 0") + private[this] var disposed: Boolean = false + /** * This size of this buffer, in bytes. */ @@ -117,11 +119,12 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { /** * Make a copy of this ChunkedByteBuffer, copying all of the backing data into new buffers. * The new buffer will share no resources with the original buffer. + * + * @param allocator a method for allocating byte buffers */ - def copy(): ChunkedByteBuffer = { + def copy(allocator: Int => ByteBuffer): ChunkedByteBuffer = { val copiedChunks = getChunks().map { chunk => - // TODO: accept an allocator in this copy method to integrate with mem. accounting systems - val newChunk = ByteBuffer.allocate(chunk.limit()) + val newChunk = allocator(chunk.limit()) newChunk.put(chunk) newChunk.flip() newChunk @@ -136,7 +139,10 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) { * unfortunately no standard API to do this. */ def dispose(): Unit = { - chunks.foreach(StorageUtils.dispose) + if (!disposed) { + chunks.foreach(StorageUtils.dispose) + disposed = true + } } } diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala new file mode 100644 index 0000000000..67b50d1e70 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStream.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import java.io.OutputStream +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.storage.StorageUtils + +/** + * An OutputStream that writes to fixed-size chunks of byte arrays. + * + * @param chunkSize size of each chunk, in bytes. 
+ */ +private[spark] class ChunkedByteBufferOutputStream( + chunkSize: Int, + allocator: Int => ByteBuffer) + extends OutputStream { + + private[this] var toChunkedByteBufferWasCalled = false + + private val chunks = new ArrayBuffer[ByteBuffer] + + /** Index of the last chunk. Starting with -1 when the chunks array is empty. */ + private[this] var lastChunkIndex = -1 + + /** + * Next position to write in the last chunk. + * + * If this equals chunkSize, it means for next write we need to allocate a new chunk. + * This can also never be 0. + */ + private[this] var position = chunkSize + private[this] var _size = 0 + + def size: Long = _size + + override def write(b: Int): Unit = { + allocateNewChunkIfNeeded() + chunks(lastChunkIndex).put(b.toByte) + position += 1 + _size += 1 + } + + override def write(bytes: Array[Byte], off: Int, len: Int): Unit = { + var written = 0 + while (written < len) { + allocateNewChunkIfNeeded() + val thisBatch = math.min(chunkSize - position, len - written) + chunks(lastChunkIndex).put(bytes, written + off, thisBatch) + written += thisBatch + position += thisBatch + } + _size += len + } + + @inline + private def allocateNewChunkIfNeeded(): Unit = { + require(!toChunkedByteBufferWasCalled, "cannot write after toChunkedByteBuffer() is called") + if (position == chunkSize) { + chunks += allocator(chunkSize) + lastChunkIndex += 1 + position = 0 + } + } + + def toChunkedByteBuffer: ChunkedByteBuffer = { + require(!toChunkedByteBufferWasCalled, "toChunkedByteBuffer() can only be called once") + toChunkedByteBufferWasCalled = true + if (lastChunkIndex == -1) { + new ChunkedByteBuffer(Array.empty[ByteBuffer]) + } else { + // Copy the first n-1 chunks to the output, and then create an array that fits the last chunk. + // An alternative would have been returning an array of ByteBuffers, with the last buffer + // bounded to only the last chunk's position. However, given our use case in Spark (to put + // the chunks in block manager), only limiting the view bound of the buffer would still + // require the block manager to store the whole chunk. 
+ val ret = new Array[ByteBuffer](chunks.size) + for (i <- 0 until chunks.size - 1) { + ret(i) = chunks(i) + ret(i).flip() + } + if (position == chunkSize) { + ret(lastChunkIndex) = chunks(lastChunkIndex) + ret(lastChunkIndex).flip() + } else { + ret(lastChunkIndex) = allocator(position) + chunks(lastChunkIndex).flip() + ret(lastChunkIndex).put(chunks(lastChunkIndex)) + ret(lastChunkIndex).flip() + StorageUtils.dispose(chunks(lastChunkIndex)) + } + new ChunkedByteBuffer(ret) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 3dded4d486..67d722c1dc 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -198,8 +198,8 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, blockId.toString) - val deserialized = serializerManager.dataDeserialize[Int](blockId, - new ChunkedByteBuffer(bytes.nioByteBuffer())).toList + val deserialized = serializerManager.dataDeserializeStream[Int](blockId, + new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream()).toList assert(deserialized === (1 to 100).toList) } } diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala index aab70e7431..f205d4f0d6 100644 --- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala @@ -52,7 +52,7 @@ class ChunkedByteBufferSuite extends SparkFunSuite { test("copy() does not affect original buffer's position") { val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8))) - chunkedByteBuffer.copy() + chunkedByteBuffer.copy(ByteBuffer.allocate) assert(chunkedByteBuffer.getChunks().head.position() === 0) } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index aaca653c58..3d1a0e9795 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -71,7 +71,8 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft */ protected def makeMemoryStore(mm: MemoryManager): MemoryStore = { val ms = mock(classOf[MemoryStore], RETURNS_SMART_NULLS) - when(ms.evictBlocksToFreeSpace(any(), anyLong())).thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) + when(ms.evictBlocksToFreeSpace(any(), anyLong(), any())) + .thenAnswer(evictBlocksToFreeSpaceAnswer(mm)) mm.setMemoryStore(ms) ms } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 98e8450fa1..2ec5319d55 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark._ -import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.BlockTransferService import 
org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv @@ -60,8 +60,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo private def makeBlockManager( maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + conf.set("spark.testing.memory", maxMem.toString) + conf.set("spark.memory.offHeap.size", maxMem.toString) val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) + val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val store = new BlockManager(name, rpcEnv, master, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -76,6 +78,9 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo conf.set("spark.authenticate", "false") conf.set("spark.driver.port", rpcEnv.address.port.toString) + conf.set("spark.testing", "true") + conf.set("spark.memory.fraction", "1") + conf.set("spark.memory.storageFraction", "1") conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -172,6 +177,10 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo testReplication(5, storageLevels) } + test("block replication - off-heap") { + testReplication(2, Seq(OFF_HEAP, StorageLevel(true, true, true, false, 2))) + } + test("block replication - 2x replication without peers") { intercept[org.scalatest.exceptions.TestFailedException] { testReplication(1, @@ -262,7 +271,8 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, 10000, numCores = 1) + conf.set("spark.testing.memory", "10000") + val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val failableStore = new BlockManager("failable-store", rpcEnv, master, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) @@ -392,10 +402,14 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo // If the block is supposed to be in memory, then drop the copy of the block in // this store test whether master is updated with zero memory usage this store if (storageLevel.useMemory) { + val sl = if (storageLevel.useOffHeap) { + StorageLevel(false, true, true, false, 1) + } else { + MEMORY_ONLY_SER + } // Force the block to be dropped by adding a number of dummy blocks (1 to 10).foreach { - i => - testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), MEMORY_ONLY_SER) + i => testStore.putSingle(s"dummy-block-$i", new Array[Byte](1000), sl) } (1 to 10).foreach { i => testStore.removeBlock(s"dummy-block-$i") diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9f3a775654..32c00ac687 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -34,7 +34,7 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark._ import 
org.apache.spark.executor.DataReadMethod -import org.apache.spark.memory.{MemoryMode, StaticMemoryManager} +import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.NettyBlockTransferService @@ -74,10 +74,12 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE name: String = SparkContext.DRIVER_IDENTIFIER, master: BlockManagerMaster = this.master, transferService: Option[BlockTransferService] = Option.empty): BlockManager = { + conf.set("spark.testing.memory", maxMem.toString) + conf.set("spark.memory.offHeap.size", maxMem.toString) val serializer = new KryoSerializer(conf) val transfer = transferService .getOrElse(new NettyBlockTransferService(conf, securityMgr, numCores = 1)) - val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) + val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, conf, memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) @@ -92,6 +94,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE System.setProperty("os.arch", "amd64") conf = new SparkConf(false) .set("spark.app.id", "test") + .set("spark.testing", "true") + .set("spark.memory.fraction", "1") + .set("spark.memory.storageFraction", "1") .set("spark.kryoserializer.buffer", "1m") .set("spark.test.useCompressedOops", "true") .set("spark.storage.unrollFraction", "0.4") @@ -518,6 +523,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY_SER) } + test("in-memory LRU storage with off-heap") { + testInMemoryLRUStorage(StorageLevel( + useDisk = false, + useMemory = true, + useOffHeap = true, + deserialized = false, replication = 1)) + } + private def testInMemoryLRUStorage(storageLevel: StorageLevel): Unit = { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) @@ -608,6 +621,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, getAsBytes = true) } + test("disk and off-heap memory storage") { + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = false) + } + + test("disk and off-heap memory storage with getLocalBytes") { + testDiskAndMemoryStorage(StorageLevel.OFF_HEAP, getAsBytes = true) + } + def testDiskAndMemoryStorage( storageLevel: StorageLevel, getAsBytes: Boolean): Unit = { @@ -817,12 +838,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE test("block store put failure") { // Use Java serializer so we can create an unserializable error. 
+ conf.set("spark.testing.memory", "1200") val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) - val memoryManager = new StaticMemoryManager( - conf, - maxOnHeapExecutionMemory = Long.MaxValue, - maxOnHeapStorageMemory = 1200, - numCores = 1) + val memoryManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, serializerManager, conf, memoryManager, mapOutputTracker, diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index 43e832dc02..145d432afe 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -27,7 +27,7 @@ import scala.reflect.ClassTag import org.scalatest._ import org.apache.spark._ -import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.memory.{MemoryMode, StaticMemoryManager} import org.apache.spark.serializer.{KryoSerializer, SerializerManager} import org.apache.spark.storage.memory.{BlockEvictionHandler, MemoryStore, PartiallySerializedBlock, PartiallyUnrolledIterator} import org.apache.spark.util._ @@ -86,7 +86,7 @@ class MemoryStoreSuite assert(memoryStore.currentUnrollMemoryForThisTask === 0) def reserveUnrollMemoryForThisTask(memory: Long): Boolean = { - memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory) + memoryStore.reserveUnrollMemoryForThisTask(TestBlockId(""), memory, MemoryMode.ON_HEAP) } // Reserve @@ -99,9 +99,9 @@ class MemoryStoreSuite assert(!reserveUnrollMemoryForThisTask(1000000)) assert(memoryStore.currentUnrollMemoryForThisTask === 800) // not granted // Release - memoryStore.releaseUnrollMemoryForThisTask(100) + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 100) assert(memoryStore.currentUnrollMemoryForThisTask === 700) - memoryStore.releaseUnrollMemoryForThisTask(100) + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 100) assert(memoryStore.currentUnrollMemoryForThisTask === 600) // Reserve again assert(reserveUnrollMemoryForThisTask(4400)) @@ -109,9 +109,9 @@ class MemoryStoreSuite assert(!reserveUnrollMemoryForThisTask(20000)) assert(memoryStore.currentUnrollMemoryForThisTask === 5000) // not granted // Release again - memoryStore.releaseUnrollMemoryForThisTask(1000) + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, 1000) assert(memoryStore.currentUnrollMemoryForThisTask === 4000) - memoryStore.releaseUnrollMemoryForThisTask() // release all + memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP) // release all assert(memoryStore.currentUnrollMemoryForThisTask === 0) } @@ -254,7 +254,7 @@ class MemoryStoreSuite assert(blockInfoManager.lockNewBlockForWriting( blockId, new BlockInfo(StorageLevel.MEMORY_ONLY_SER, classTag, tellMaster = false))) - val res = memoryStore.putIteratorAsBytes(blockId, iter, classTag) + val res = memoryStore.putIteratorAsBytes(blockId, iter, classTag, MemoryMode.ON_HEAP) blockInfoManager.unlock(blockId) res } @@ -312,7 +312,7 @@ class MemoryStoreSuite assert(blockInfoManager.lockNewBlockForWriting( "b1", new BlockInfo(StorageLevel.MEMORY_ONLY_SER, ClassTag.Any, tellMaster = false))) - val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any) + val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any, MemoryMode.ON_HEAP) 
blockInfoManager.unlock("b1") assert(res.isLeft) assert(memoryStore.currentUnrollMemoryForThisTask > 0) @@ -333,7 +333,7 @@ class MemoryStoreSuite assert(blockInfoManager.lockNewBlockForWriting( "b1", new BlockInfo(StorageLevel.MEMORY_ONLY_SER, ClassTag.Any, tellMaster = false))) - val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any) + val res = memoryStore.putIteratorAsBytes("b1", bigIterator, ClassTag.Any, MemoryMode.ON_HEAP) blockInfoManager.unlock("b1") assert(res.isLeft) assert(memoryStore.currentUnrollMemoryForThisTask > 0) @@ -395,7 +395,7 @@ class MemoryStoreSuite val blockId = BlockId("rdd_3_10") blockInfoManager.lockNewBlockForWriting( blockId, new BlockInfo(StorageLevel.MEMORY_ONLY, ClassTag.Any, tellMaster = false)) - memoryStore.putBytes(blockId, 13000, () => { + memoryStore.putBytes(blockId, 13000, MemoryMode.ON_HEAP, () => { fail("A big ByteBuffer that cannot be put into MemoryStore should not be created") }) } @@ -404,7 +404,7 @@ class MemoryStoreSuite val (memoryStore, _) = makeMemoryStore(12000) val blockId = BlockId("rdd_3_10") var bytes: ChunkedByteBuffer = null - memoryStore.putBytes(blockId, 10000, () => { + memoryStore.putBytes(blockId, 10000, MemoryMode.ON_HEAP, () => { bytes = new ChunkedByteBuffer(ByteBuffer.allocate(10000)) bytes }) diff --git a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala deleted file mode 100644 index 361ec95654..0000000000 --- a/core/src/test/scala/org/apache/spark/util/io/ByteArrayChunkOutputStreamSuite.scala +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util.io - -import scala.util.Random - -import org.apache.spark.SparkFunSuite - - -class ByteArrayChunkOutputStreamSuite extends SparkFunSuite { - - test("empty output") { - val o = new ByteArrayChunkOutputStream(1024) - assert(o.toArrays.length === 0) - } - - test("write a single byte") { - val o = new ByteArrayChunkOutputStream(1024) - o.write(10) - assert(o.toArrays.length === 1) - assert(o.toArrays.head.toSeq === Seq(10.toByte)) - } - - test("write a single near boundary") { - val o = new ByteArrayChunkOutputStream(10) - o.write(new Array[Byte](9)) - o.write(99) - assert(o.toArrays.length === 1) - assert(o.toArrays.head(9) === 99.toByte) - } - - test("write a single at boundary") { - val o = new ByteArrayChunkOutputStream(10) - o.write(new Array[Byte](10)) - o.write(99) - assert(o.toArrays.length === 2) - assert(o.toArrays(1).length === 1) - assert(o.toArrays(1)(0) === 99.toByte) - } - - test("single chunk output") { - val ref = new Array[Byte](8) - Random.nextBytes(ref) - val o = new ByteArrayChunkOutputStream(10) - o.write(ref) - val arrays = o.toArrays - assert(arrays.length === 1) - assert(arrays.head.length === ref.length) - assert(arrays.head.toSeq === ref.toSeq) - } - - test("single chunk output at boundary size") { - val ref = new Array[Byte](10) - Random.nextBytes(ref) - val o = new ByteArrayChunkOutputStream(10) - o.write(ref) - val arrays = o.toArrays - assert(arrays.length === 1) - assert(arrays.head.length === ref.length) - assert(arrays.head.toSeq === ref.toSeq) - } - - test("multiple chunk output") { - val ref = new Array[Byte](26) - Random.nextBytes(ref) - val o = new ByteArrayChunkOutputStream(10) - o.write(ref) - val arrays = o.toArrays - assert(arrays.length === 3) - assert(arrays(0).length === 10) - assert(arrays(1).length === 10) - assert(arrays(2).length === 6) - - assert(arrays(0).toSeq === ref.slice(0, 10)) - assert(arrays(1).toSeq === ref.slice(10, 20)) - assert(arrays(2).toSeq === ref.slice(20, 26)) - } - - test("multiple chunk output at boundary size") { - val ref = new Array[Byte](30) - Random.nextBytes(ref) - val o = new ByteArrayChunkOutputStream(10) - o.write(ref) - val arrays = o.toArrays - assert(arrays.length === 3) - assert(arrays(0).length === 10) - assert(arrays(1).length === 10) - assert(arrays(2).length === 10) - - assert(arrays(0).toSeq === ref.slice(0, 10)) - assert(arrays(1).toSeq === ref.slice(10, 20)) - assert(arrays(2).toSeq === ref.slice(20, 30)) - } -} diff --git a/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala new file mode 100644 index 0000000000..226622075a --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/io/ChunkedByteBufferOutputStreamSuite.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.io + +import java.nio.ByteBuffer + +import scala.util.Random + +import org.apache.spark.SparkFunSuite + + +class ChunkedByteBufferOutputStreamSuite extends SparkFunSuite { + + test("empty output") { + val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + assert(o.toChunkedByteBuffer.size === 0) + } + + test("write a single byte") { + val o = new ChunkedByteBufferOutputStream(1024, ByteBuffer.allocate) + o.write(10) + val chunkedByteBuffer = o.toChunkedByteBuffer + assert(chunkedByteBuffer.getChunks().length === 1) + assert(chunkedByteBuffer.getChunks().head.array().toSeq === Seq(10.toByte)) + } + + test("write a single near boundary") { + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(new Array[Byte](9)) + o.write(99) + val chunkedByteBuffer = o.toChunkedByteBuffer + assert(chunkedByteBuffer.getChunks().length === 1) + assert(chunkedByteBuffer.getChunks().head.array()(9) === 99.toByte) + } + + test("write a single at boundary") { + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(new Array[Byte](10)) + o.write(99) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 2) + assert(arrays(1).length === 1) + assert(arrays(1)(0) === 99.toByte) + } + + test("single chunk output") { + val ref = new Array[Byte](8) + Random.nextBytes(ref) + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(ref) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("single chunk output at boundary size") { + val ref = new Array[Byte](10) + Random.nextBytes(ref) + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(ref) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 1) + assert(arrays.head.length === ref.length) + assert(arrays.head.toSeq === ref.toSeq) + } + + test("multiple chunk output") { + val ref = new Array[Byte](26) + Random.nextBytes(ref) + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(ref) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 6) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 26)) + } + + test("multiple chunk output at boundary size") { + val ref = new Array[Byte](30) + Random.nextBytes(ref) + val o = new ChunkedByteBufferOutputStream(10, ByteBuffer.allocate) + o.write(ref) + val arrays = o.toChunkedByteBuffer.getChunks().map(_.array()) + assert(arrays.length === 3) + assert(arrays(0).length === 10) + assert(arrays(1).length === 10) + assert(arrays(2).length === 10) + + assert(arrays(0).toSeq === ref.slice(0, 10)) + assert(arrays(1).toSeq === ref.slice(10, 20)) + assert(arrays(2).toSeq === ref.slice(20, 30)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala index c56520b1e2..53fccd8d5e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDD.scala @@ -162,7 +162,8 @@ class WriteAheadLogBackedBlockRDD[T: ClassTag]( logDebug(s"Stored partition data of $this into block manager with level $storageLevel") dataRead.rewind() } - serializerManager.dataDeserialize(blockId, new ChunkedByteBuffer(dataRead)) + serializerManager + .dataDeserializeStream(blockId, new ChunkedByteBuffer(dataRead).toInputStream()) .asInstanceOf[Iterator[T]] } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 4e77cd6347..5fc53bcb91 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -157,7 +157,8 @@ class ReceivedBlockHandlerSuite val reader = new FileBasedWriteAheadLogRandomReader(fileSegment.path, hadoopConf) val bytes = reader.read(fileSegment) reader.close() - serializerManager.dataDeserialize(generateBlockId(), new ChunkedByteBuffer(bytes)).toList + serializerManager.dataDeserializeStream( + generateBlockId(), new ChunkedByteBuffer(bytes).toInputStream()).toList } loggedData shouldEqual data } -- cgit v1.2.3 From 0b7d4966ca7e02f351c4b92a74789cef4799fcb1 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Fri, 1 Apr 2016 15:00:38 -0700 Subject: [SPARK-14316][SQL] StateStoreCoordinator should extend ThreadSafeRpcEndpoint ## What changes were proposed in this pull request? RpcEndpoint is not thread safe and allows multiple messages to be processed at the same time. StateStoreCoordinator should use ThreadSafeRpcEndpoint. ## How was this patch tested? Existing unit tests. Author: Shixiong Zhu Closes #12100 from zsxwing/fix-StateStoreCoordinator. --- .../sql/execution/streaming/state/StateStoreCoordinator.scala | 4 ++-- .../spark/sql/execution/streaming/state/StateStoreRDDSuite.scala | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 5aa0636850..812e1b0a39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.util.RpcUtils @@ -112,7 +112,7 @@ private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointR * Class for coordinating instances of [[StateStore]]s loaded in executors across the cluster, * and get their locations for job scheduling. 
*/ -private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends RpcEndpoint { +private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { private val instances = new mutable.HashMap[StateStoreId, ExecutorCacheTaskLocation] override def receive: PartialFunction[Any, Unit] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index df50cbde56..85db05157c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -124,11 +124,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") - eventually(timeout(10 seconds)) { - assert( - coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === - Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - } + assert( + coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === + Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( increment, path, opId, storeVersion = 0, keySchema, valueSchema) -- cgit v1.2.3 From 0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 1 Apr 2016 15:15:16 -0700 Subject: [SPARK-14255][SQL] Streaming Aggregation This PR adds the ability to perform aggregations inside a `ContinuousQuery`. In order to implement this feature, the planning of aggregation has been augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression: - Partial Aggregation - Shuffle - Partial Merge (now there is at most 1 tuple per group) - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) - Partial Merge (now there is at most 1 tuple per group) - StateStoreSave (saves the tuple for the next batch) - Complete (output the current result of the aggregation) The following refactoring was also performed to allow us to plug into existing code: - The get/put implementation is taken from #12013 - The logic for breaking down and de-duping the physical execution of aggregation has been moved into a new pattern `PhysicalAggregation` - The `AttributeReference` used to identify the result of an `AggregateFunction` has been moved into the `AggregateExpression` container. This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`. Further cleanup (using a different aggregation container for logical/physical plans) is deferred to a follow-up. - Some planning logic is moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case. - The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes.
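To make the shape of such a query concrete, a minimal sketch is shown below (illustrative only: the `SparkContext`/`SQLContext` setup is an assumption, the snippet runs in plain batch mode, and the streaming source/sink wiring is elided; see the new `StreamingAggregationSuite` added by this patch for the actual streaming tests):

```scala
// Sketch only: a groupBy/agg whose logical Aggregate matches the shape that
// StatefulAggregationStrategy handles when planned inside a ContinuousQuery.
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.count

val sc = SparkContext.getOrCreate()          // assumes a running Spark application
val sqlContext = SQLContext.getOrCreate(sc)
import sqlContext.implicits._

val counts = Seq(("a", 1), ("b", 1), ("a", 1))
  .toDF("key", "value")
  .groupBy($"key")
  .agg(count($"value").as("cnt"))

// In batch mode this plans as Partial -> Shuffle -> Final aggregation.
// Under IncrementalExecution the same logical plan follows the progression above:
//   Partial -> Shuffle -> PartialMerge -> StateStoreRestore
//     -> PartialMerge -> StateStoreSave -> Complete
counts.show()
```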
Author: Michael Armbrust Closes #12048 from marmbrus/statefulAgg. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 9 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../apache/spark/sql/catalyst/errors/package.scala | 7 +- .../expressions/aggregate/interfaces.scala | 37 ++++- .../catalyst/expressions/namedExpressions.scala | 2 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 +- .../spark/sql/catalyst/planning/patterns.scala | 73 ++++++++++ .../apache/spark/sql/catalyst/plans/PlanTest.scala | 3 + .../spark/sql/execution/QueryExecution.scala | 24 +++- .../org/apache/spark/sql/execution/SparkPlan.scala | 7 + .../apache/spark/sql/execution/SparkPlanner.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 92 ++++--------- .../org/apache/spark/sql/execution/Window.scala | 2 +- .../aggregate/TungstenAggregationIterator.scala | 4 +- .../spark/sql/execution/aggregate/utils.scala | 121 ++++++++++++---- .../execution/streaming/IncrementalExecution.scala | 72 ++++++++++ .../execution/streaming/StatefulAggregate.scala | 119 ++++++++++++++++ .../sql/execution/streaming/StreamExecution.scala | 12 +- .../spark/sql/execution/streaming/memory.scala | 4 +- .../state/HDFSBackedStateStoreProvider.scala | 36 +++-- .../sql/execution/streaming/state/StateStore.scala | 19 ++- .../execution/streaming/state/StateStoreConf.scala | 4 +- .../execution/streaming/state/StateStoreRDD.scala | 17 +-- .../sql/execution/streaming/state/package.scala | 21 ++- .../org/apache/spark/sql/execution/subquery.scala | 11 +- .../apache/spark/sql/expressions/Aggregator.scala | 4 +- .../apache/spark/sql/internal/SessionState.scala | 16 +-- .../scala/org/apache/spark/sql/StreamTest.scala | 36 +++-- .../apache/spark/sql/execution/SparkPlanTest.scala | 10 +- .../streaming/state/StateStoreRDDSuite.scala | 152 ++++++++++++--------- .../streaming/state/StateStoreSuite.scala | 61 +++++---- .../sql/streaming/StreamingAggregationSuite.scala | 132 ++++++++++++++++++ .../apache/spark/sql/hive/HiveSessionState.scala | 5 +- 33 files changed, 827 insertions(+), 305 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d82ee3a205..05e2b9a447 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -336,6 +336,11 @@ class Analyzer( Last(ifExpr(expr), Literal(true)) case a: AggregateFunction => a.withNewChildren(a.children.map(ifExpr)) + }.transform { + // We are duplicating aggregates that are now computing a different value for each + // pivot value. + // TODO: Don't construct the physical container until after analysis. 
+ case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } if (filteredAggregate.fastEquals(aggregate)) { throw new AnalysisException( @@ -1153,11 +1158,11 @@ class Analyzer( // Extract Windowed AggregateExpression case we @ WindowExpression( - AggregateExpression(function, mode, isDistinct), + ae @ AggregateExpression(function, _, _, _), spec: WindowSpecDefinition) => val newChildren = function.children.map(extractExpr) val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] - val newAgg = AggregateExpression(newFunction, mode, isDistinct) + val newAgg = ae.copy(aggregateFunction = newFunction) seenWindowAggregates += newAgg WindowExpression(newAgg, spec) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1d1e892e32..4880502398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -76,7 +76,7 @@ trait CheckAnalysis { case g: GroupingID => failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") - case w @ WindowExpression(AggregateExpression(_, _, true), _) => + case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0d44d1dd96..0420b4b538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode package object errors { class TreeNodeException[TreeType <: TreeNode[_]]( - tree: TreeType, msg: String, cause: Throwable) + @transient val tree: TreeType, + msg: String, + cause: Throwable) extends Exception(msg, cause) { + val treeString = tree.toString + // Yes, this is the same as a default parameter, but... those don't seem to work with SBT // external project dependencies for some reason. 
def this(tree: TreeType, msg: String) = this(tree, msg, null) override def getMessage: String = { - val treeString = tree.toString s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index ff3064ac66..d31ccf9985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ @@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable { override def children: Seq[Expression] = Nil } +object AggregateExpression { + def apply( + aggregateFunction: AggregateFunction, + mode: AggregateMode, + isDistinct: Boolean): AggregateExpression = { + AggregateExpression( + aggregateFunction, + mode, + isDistinct, + NamedExpression.newExprId) + } +} + /** * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. @@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable { private[sql] case class AggregateExpression( aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) + isDistinct: Boolean, + resultId: ExprId) extends Expression with Unevaluable { + lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) { + AttributeReference( + aggregateFunction.toString, + aggregateFunction.dataType, + aggregateFunction.nullable)(exprId = resultId) + } else { + // This is a bit of a hack. Really we should not be constructing this container and reasoning + // about datatypes / aggregation mode until after we have finished analysis and made it to + // planning. + UnresolvedAttribute(aggregateFunction.toString) + } + + // We compute the same thing regardless of our final result. 
+ override lazy val canonicalized: Expression = + AggregateExpression( + aggregateFunction.canonicalized.asInstanceOf[AggregateFunction], + mode, + isDistinct, + ExprId(0)) + override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType override def foldable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 262582ca5d..2307122ea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -329,7 +329,7 @@ case class PrettyAttribute( override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifier: Option[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def nullable: Boolean = throw new UnsupportedOperationException + override def nullable: Boolean = true } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a7a948ef1b..326933ec9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) @@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => + case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. - AggregateExpression(Count(Literal(1)), mode, isDistinct = false) + ae.copy(aggregateFunction = Count(Literal(1))) // For Coalesce, remove null literals. 
case e @ Coalesce(children) => @@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale) + MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct) + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c927077d0..28d2c445b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -216,3 +217,75 @@ object IntegerIndex { case _ => None } } + +/** + * An extractor used when planning the physical execution of an aggregation. Compared with a logical + * aggregation, the following transformations are performed: + * - Unnamed grouping expressions are named so that they can be referred to across phases of + * aggregation + * - Aggregations that appear multiple times are deduplicated. + * - The compution of the aggregations themselves is separated from the final result. For example, + * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final + * computation that computes `count.resultAttribute + 1`. + */ +object PhysicalAggregation { + // groupingExpressions, aggregateExpressions, resultExpressions, child + type ReturnType = + (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. 
+ val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct + + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case ae: AggregateExpression => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + ae.resultAttribute + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + Some(( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + rewrittenResultExpressions, + child)) + + case _ => None + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index aa5d4330d3..7191936699 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ @@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => Alias(a.child, a.name)(exprId = ExprId(0)) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 912b84abc1..4843553211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { + // TODO: Move the planner an optimizer into here from SessionState. + protected def planner = sqlContext.sessionState.planner + def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) @@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next() + planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + /** + * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal + * row format conversions as needed. + */ + protected def prepareForExecution(plan: SparkPlan): SparkPlan = { + preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } + + /** A sequence of rules that will be applied in order to the physical plan before execution. 
*/ + protected def preparations: Seq[Rule[SparkPlan]] = Seq( + PlanSubqueries(sqlContext), + EnsureRequirements(sqlContext.conf), + CollapseCodegenStages(sqlContext.conf), + ReuseExchange(sqlContext.conf)) + protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 010ed7f500..b1b3d4ac81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan { override def producedAttributes: AttributeSet = outputSet } +object UnaryNode { + def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match { + case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head)) + case _ => None + } +} + private[sql] trait UnaryNode extends SparkPlan { def child: SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 9da2c74c62..ac8072f3ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -26,13 +26,13 @@ import org.apache.spark.sql.internal.SQLConf class SparkPlanner( val sparkContext: SparkContext, val conf: SQLConf, - val experimentalMethods: ExperimentalMethods) + val extraStrategies: Seq[Strategy]) extends SparkStrategies { def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - experimentalMethods.extraStrategies ++ ( + extraStrategies ++ ( FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7a2e2b7382..5bcc172ca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -203,29 +202,33 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan aggregation queries that are computed incrementally as part of a + * [[org.apache.spark.sql.ContinuousQuery]]. 
Currently this rule is injected into the planner + * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] + */ + object StatefulAggregationStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalAggregation( + namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + + aggregate.Utils.planStreamingAggregation( + namedGroupingExpressions, + aggregateExpressions, + rewrittenResultExpressions, + planLater(child)) + + case _ => Nil + } + } + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, resultExpressions, child) => - // A single aggregate expression might appear multiple times in resultExpressions. - // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression => agg - } - }.distinct - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. - val aggregateFunctionToAttribute = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> attribute - }.toMap + case PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) => val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -233,41 +236,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This is a sanity check. We should not reach here when we have multiple distinct // column sets. Our MultipleDistinctRewriter should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + - "Spark user mailing list.") - } - - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - - // The original `resultExpressions` are a set of expressions which may reference - // aggregate expressions, grouping column values, and constants. When aggregate operator - // emits output rows, we will use `resultExpressions` to generate an output projection - // which takes the grouping columns and final aggregate result buffer as input. 
- // Thus, we must re-write the result expressions so that their attributes match up with - // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case AggregateExpression(aggregateFunction, _, isDistinct) => - // The final aggregation buffer's attributes will be `finalAggregationAttributes`, - // so replace each aggregate expression by its corresponding attribute in the set: - aggregateFunctionToAttribute(aggregateFunction, isDistinct) - case expression => - // Since we're using `namedGroupingAttributes` to extract the grouping key - // columns, we need to replace grouping key expressions with their corresponding - // attributes. We do not rely on the equality check at here since attributes may - // differ cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + "Spark user mailing list.") } val aggregateOperator = @@ -277,26 +246,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "aggregate functions which don't support partial aggregation.") } else { aggregate.Utils.planAggregateWithoutPartial( - namedGroupingExpressions.map(_._2), + groupingExpressions, aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } } else if (functionsWithDistinct.isEmpty) { aggregate.Utils.planAggregateWithoutDistinct( - namedGroupingExpressions.map(_._2), + groupingExpressions, aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } else { aggregate.Utils.planAggregateWithOneDistinct( - namedGroupingExpressions.map(_._2), + groupingExpressions, functionsWithDistinct, functionsWithoutDistinct, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 270c09aff3..7acf020b28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -177,7 +177,7 @@ case class Window( case e @ WindowExpression(function, spec) => val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] function match { - case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f) + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) case f => sys.error(s"Unsupported window function: $f") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 213bca907b..ce504e20e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -242,9 +242,9 @@ class TungstenAggregationIterator( // Basically the value of the KVIterator returned by externalSorter // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. 
val newExpressions = aggregateExpressions.map { - case agg @ AggregateExpression(_, Partial, _) => + case agg @ AggregateExpression(_, Partial, _, _) => agg.copy(mode = PartialMerge) - case agg @ AggregateExpression(_, Complete, _) => + case agg @ AggregateExpression(_, Complete, _, _) => agg.copy(mode = Final) case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 1e113ccd4e..4682949fa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -29,15 +30,11 @@ object Utils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = completeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } - + val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingExpressions), groupingExpressions = groupingExpressions, @@ -83,7 +80,6 @@ object Utils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. @@ -111,9 +107,7 @@ object Utils { val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), @@ -131,7 +125,6 @@ object Utils { groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], functionsWithoutDistinct: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -151,9 +144,7 @@ object Utils { // 1. Create an Aggregate Operator for partial aggregations. 
val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val aggregateAttributes = aggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. @@ -169,9 +160,7 @@ object Utils { // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val aggregateAttributes = aggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), @@ -190,7 +179,7 @@ object Utils { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression(aggregateFunction, mode, true) => + case agg @ AggregateExpression(aggregateFunction, mode, true, _) => aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] } @@ -199,9 +188,7 @@ object Utils { val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val mergeAggregateAttributes = mergeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) val (distinctAggregateExpressions, distinctAggregateAttributes) = rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because @@ -211,7 +198,7 @@ object Utils { val expr = AggregateExpression(func, Partial, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute - val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip @@ -232,9 +219,7 @@ object Utils { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) val (distinctAggregateExpressions, distinctAggregateAttributes) = rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => @@ -245,7 +230,7 @@ object Utils { val expr = AggregateExpression(func, Final, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute - val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip @@ -261,4 +246,90 @@ 
object Utils { finalAndCompleteAggregate :: Nil } + + /** + * Plans a streaming aggregation using the following progression: + * - Partial Aggregation + * - Shuffle + * - Partial Merge (now there is at most 1 tuple per group) + * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) + * - PartialMerge (now there is at most 1 tuple per group) + * - StateStoreSave (saves the tuple for the next batch) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregation( + groupingExpressions: Seq[NamedExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + val restored = StateStoreRestore(groupingAttributes, None, partialMerged1) + + val partialMerged2: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored) + } + + val saved = StateStoreSave(groupingAttributes, None, partialMerged2) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = 
resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala new file mode 100644 index 0000000000..aaced49dd1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode} + +/** + * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] + * plan incrementally. Possibly preserving state in between each execution. + */ +class IncrementalExecution( + ctx: SQLContext, + logicalPlan: LogicalPlan, + checkpointLocation: String, + currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) { + + // TODO: make this always part of planning. + val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil + + // Modified planner with stateful operations. + override def planner: SparkPlanner = + new SparkPlanner( + sqlContext.sparkContext, + sqlContext.conf, + stateStrategy) + + /** + * Records the current id for a given stateful operator in the query plan as the `state` + * preperation walks the query plan. + */ + private var operatorId = 0 + + /** Locates save/restore pairs surrounding aggregation. */ + val state = new Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan transform { + case StateStoreSave(keys, None, + UnaryNode(agg, + StateStoreRestore(keys2, None, child))) => + val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1) + operatorId += 1 + + StateStoreSave( + keys, + Some(stateId), + agg.withNewChildren( + StateStoreRestore( + keys, + Some(stateId), + child) :: Nil)) + } + } + + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala new file mode 100644 index 0000000000..595774761c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.SparkPlan + +/** Used to identify the state store for a given operator. */ +case class OperatorStateId( + checkpointLocation: String, + operatorId: Long, + batchId: Long) + +/** + * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should + * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + */ +trait StatefulOperator extends SparkPlan { + def stateId: Option[OperatorStateId] + + protected def getStateId: OperatorStateId = attachTree(this) { + stateId.getOrElse { + throw new IllegalStateException("State location not present for execution") + } + } +} + +/** + * For each input tuple, the key is calculated and the value from the [[StateStore]] is added + * to the stream (in addition to the input tuple) if present. + */ +case class StateStoreRestore( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + row +: savedState.toSeq + } + } + } + override def output: Seq[Attribute] = child.output +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. 
+ */ +case class StateStoreSave( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + new Iterator[InternalRow] { + private[this] val baseIterator = iter + private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + row + } + } + } + } + + override def output: Seq[Attribute] = child.output +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index c4e410d92c..511e30c70c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -272,6 +273,8 @@ class StreamExecution( private def runBatch(): Unit = { val startTime = System.nanoTime() + // TODO: Move this to IncrementalExecution. + // Request unprocessed data from all sources. val newData = availableOffsets.flatMap { case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => @@ -305,13 +308,14 @@ class StreamExecution( } val optimizerStart = System.nanoTime() - - lastExecution = new QueryExecution(sqlContext, newPlan) - val executedPlan = lastExecution.executedPlan + lastExecution = + new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId) + lastExecution.executedPlan val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 logDebug(s"Optimized batch in ${optimizerTime}ms") - val nextBatch = Dataset.ofRows(sqlContext, newPlan) + val nextBatch = + new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema)) sink.addBatch(currentBatchId - 1, nextBatch) awaitBatchLock.synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 0f91e59e04..7d97f81b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -108,7 +108,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. 
*/ -class MemorySink(schema: StructType) extends Sink with Logging { +class MemorySink(val schema: StructType) extends Sink with Logging { /** An order list of batches that have been written to this [[Sink]]. */ private val batches = new ArrayBuffer[Array[Row]]() @@ -117,6 +117,8 @@ class MemorySink(schema: StructType) extends Sink with Logging { batches.flatten } + def lastBatch: Seq[Row] = batches.last + def toDebugString: String = synchronized { batches.zipWithIndex.map { case (b, i) => val dataStr = try b.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ee015baf3f..998eb82de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -81,7 +81,7 @@ private[state] class HDFSBackedStateStoreProvider( trait STATE case object UPDATING extends STATE case object COMMITTED extends STATE - case object CANCELLED extends STATE + case object ABORTED extends STATE private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") @@ -94,15 +94,14 @@ private[state] class HDFSBackedStateStoreProvider( override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id - /** - * Update the value of a key using the value generated by the update function. - * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous - * versions of the store data. - */ - override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot update after already committed or cancelled") - val oldValueOption = Option(mapToUpdate.get(key)) - val value = updateFunc(oldValueOption) + override def get(key: UnsafeRow): Option[UnsafeRow] = { + Option(mapToUpdate.get(key)) + } + + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or cancelled") + + val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) Option(allUpdates.get(key)) match { @@ -115,8 +114,7 @@ private[state] class HDFSBackedStateStoreProvider( case None => // There was no prior update, so mark this as added or updated according to its presence // in previous version. - val update = - if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) + val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value) allUpdates.put(key, update) } writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) @@ -148,7 +146,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { - verify(state == UPDATING, "Cannot commit again after already committed or cancelled") + verify(state == UPDATING, "Cannot commit after already committed or cancelled") try { finalizeDeltaFile(tempDeltaFileStream) @@ -164,8 +162,8 @@ private[state] class HDFSBackedStateStoreProvider( } /** Cancel all the updates made on this store. This store will not be usable any more. 
*/ - override def cancel(): Unit = { - state = CANCELLED + override def abort(): Unit = { + state = ABORTED if (tempDeltaFileStream != null) { tempDeltaFileStream.close() } @@ -176,8 +174,8 @@ private[state] class HDFSBackedStateStoreProvider( } /** - * Get an iterator of all the store data. This can be called only after committing the - * updates. + * Get an iterator of all the store data. + * This can be called only after committing all the updates made in the current thread. */ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { verify(state == COMMITTED, "Cannot get iterator of store data before comitting") @@ -186,7 +184,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Get an iterator of all the updates made to the store in the current version. - * This can be called only after committing the updates. + * This can be called only after committing all the updates made in the current thread. */ override def updates(): Iterator[StoreUpdate] = { verify(state == COMMITTED, "Cannot get iterator of updates before committing") @@ -196,7 +194,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Whether all updates have been committed */ - override def hasCommitted: Boolean = { + override private[state] def hasCommitted: Boolean = { state == COMMITTED } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ca5c864d9e..d60e6185ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -47,12 +47,11 @@ trait StateStore { /** Version of the data in this store before committing updates. */ def version: Long - /** - * Update the value of a key using the value generated by the update function. - * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous - * versions of the store data. - */ - def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit + /** Get the current value of a key. */ + def get(key: UnsafeRow): Option[UnsafeRow] + + /** Put a new value for a key. */ + def put(key: UnsafeRow, value: UnsafeRow) /** * Remove keys that match the following condition. @@ -65,24 +64,24 @@ trait StateStore { def commit(): Long /** Cancel all the updates that have been made to the store. */ - def cancel(): Unit + def abort(): Unit /** * Iterator of store data after a set of updates have been committed. - * This can be called only after commitUpdates() has been called in the current thread. + * This can be called only after committing all the updates made in the current thread. */ def iterator(): Iterator[(UnsafeRow, UnsafeRow)] /** * Iterator of the updates that have been committed. - * This can be called only after commitUpdates() has been called in the current thread. + * This can be called only after committing all the updates made in the current thread. 
*/ def updates(): Iterator[StoreUpdate] /** * Whether all updates have been committed */ - def hasCommitted: Boolean + private[state] def hasCommitted: Boolean } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index cca22a0af8..f0f1f3a1a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. */ -private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { +private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { def this() = this(new SQLConf) @@ -31,7 +31,7 @@ private[state] class StateStoreConf(@transient private val conf: SQLConf) extend val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) } -private[state] object StateStoreConf { +private[streaming] object StateStoreConf { val empty = new StateStoreConf() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 3318660895..df3d82c113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - - Utils.tryWithSafeFinally { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) - store = StateStore.get( - storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) - val inputIter = dataRDD.iterator(partition, ctxt) - val outputIter = storeUpdateFunction(store, inputIter) - assert(store.hasCommitted) - outputIter - } { - if (store != null) store.cancel() - } + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + store = StateStore.get( + storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + val inputIter = dataRDD.iterator(partition, ctxt) + storeUpdateFunction(store, inputIter) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b249e37921..9b6d0918e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -28,37 +28,36 @@ package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { /** Map each partition of a RDD along with data in a [[StateStore]]. 
*/ - def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + def mapPartitionsWithStateStore[U: ClassTag]( + sqlContext: SQLContext, checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, - valueSchema: StructType - )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { + valueSchema: StructType)( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { - mapPartitionWithStateStore( - storeUpdateFunction, + mapPartitionsWithStateStore( checkpointLocation, operatorId, storeVersion, keySchema, valueSchema, new StateStoreConf(sqlContext.conf), - Some(sqlContext.streams.stateStoreCoordinator)) + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) } /** Map each partition of a RDD along with data in a [[StateStore]]. */ - private[state] def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, valueSchema: StructType, storeConf: StateStoreConf, - storeCoordinator: Option[StateStoreCoordinatorRef] - ): StateStoreRDD[T, U] = { + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( dataRDD, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 0d580703f5..4b3091ba22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.DataType /** @@ -60,14 +60,13 @@ case class ScalarSubquery( } /** - * Convert the subquery from logical plan into executed plan. + * Plans scalar subqueries from that are present in the given [[SparkPlan]]. 
*/ -case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] { +case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => - val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next() - val executedPlan = sessionState.prepareForExecution.execute(sparkPlan) + val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan ScalarSubquery(executedPlan, subquery.exprId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 844f3051fa..9cb356f1ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -84,10 +84,10 @@ abstract class Aggregator[-I, B, O] extends Serializable { implicit bEncoder: Encoder[B], cEncoder: Encoder[O]): TypedColumn[I, O] = { val expr = - new AggregateExpression( + AggregateExpression( TypedAggregateExpression(this), Complete, - false) + isDistinct = false) new TypedColumn[I, O](expr, encoderFor[O]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index f7fdfacd31..cd3d254d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -86,20 +86,8 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Planner that converts optimized logical plans to physical plans. */ - lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) - - /** - * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal - * row format conversions as needed. - */ - lazy val prepareForExecution = new RuleExecutor[SparkPlan] { - override val batches: Seq[Batch] = Seq( - Batch("Subquery", Once, PlanSubqueries(SessionState.this)), - Batch("Add exchange", Once, EnsureRequirements(conf)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)), - Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf)) - ) - } + def planner: SparkPlanner = + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index b5be7ef47e..550c3c6f9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts { def apply[A : Encoder](data: A*): CheckAnswerRows = { val encoder = encoderFor[A] val toExternalRow = RowEncoder(encoder.schema) - CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d)))) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false) } - def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows) + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false) } - case class CheckAnswerRows(expectedAnswer: Seq[Row]) + /** + * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. 
+ * This operation automatically blocks until all added data has been processed. + */ + object CheckLastBatch { + def apply[A : Encoder](data: A*): CheckAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true) + } + + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true) + } + + case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}" + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } /** Stops the stream. It must currently be running. */ @@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts { """.stripMargin def verify(condition: => Boolean, message: String): Unit = { - try { - Assertions.assert(condition) - } catch { - case NonFatal(e) => - failTest(message, e) + if (!condition) { + failTest(message) } } @@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts { case a: AddData => awaiting.put(a.source, a.addData()) - case CheckAnswerRows(expectedAnswer) => + case CheckAnswerRows(expectedAnswer, lastOnly) => verify(currentStream != null, "stream not running") // Block until all data added has been processed @@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts { } } - val allData = try sink.allData catch { + val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } - QueryTest.sameRows(expectedAnswer, allData).foreach { + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index ed0d3f56e5..38318740a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -231,10 +231,8 @@ object SparkPlanTest { } private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - // A very simple resolver to make writing tests easier. In contrast to the real resolver - // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute( - outputPlan transform { + val execution = new QueryExecution(sqlContext, null) { + override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap plan transformExpressions { @@ -243,8 +241,8 @@ object SparkPlanTest { sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) } } - ) - resolvedPlan.executeCollectPublic().toSeq + } + execution.executedPlan.executeCollectPublic().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 85db05157c..6be94eb24f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CompletionIterator, Utils} class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { @@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("versioning and immutability") { - quietly { - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContet = new SQLContext(sc) - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val increment = (store: StateStore, iter: Iterator[String]) => { - iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) - } - store.commit() - store.iterator().map(rowsToStringInt) - } - val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + withSpark(new SparkContext(sparkConf)) { sc => + val sqlContext = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 + val rdd1 = + makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + increment) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. 
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } - // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 1, keySchema, valueSchema) - assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + test("recovering from files") { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + implicit val sqlContext = new SQLContext(sc) + makeRDD(sc, Seq("a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) + } - // Make sure the previous RDD still has the same data. - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + // Generate RDDs and state store data + withSpark(new SparkContext(sparkConf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) } } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(sparkConf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } } - test("recovering from files") { - quietly { - val opId = 0 + test("usage with iterators - only gets and only puts") { + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContext = new SQLContext(sc) val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 - def makeStoreRDD( - sc: SparkContext, - seq: Seq[String], - storeVersion: Int): RDD[(String, Int)] = { - implicit val sqlContext = new SQLContext(sc) - makeRDD(sc, Seq("a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion, keySchema, valueSchema) + // Returns an iterator of the incremented value made into the store + def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = { + val resIterator = iter.map { s => + val key = stringToRow(s) + val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val newValue = oldValue + 1 + store.put(key, intToRow(newValue)) + (s, newValue) + } + CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, { + store.commit() + }) } - // Generate RDDs and state store data - withSpark(new SparkContext(sparkConf)) { sc => - for (i <- 1 to 20) { - require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + def iteratorOfGets( + store: StateStore, + iter: Iterator[String]): Iterator[(String, Option[Int])] = { + iter.map { s => + val key = stringToRow(s) + val value = store.get(key).map(rowToInt) + (s, value) } } - // With a new context, try using the earlier state store data - withSpark(new SparkContext(sparkConf)) { sc => - assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) - } + val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) + + val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) + assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) + + val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) + assert(rddOfGets2.collect().toSet 
=== Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } } @@ -128,8 +159,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) require(rdd.partitions.length === 2) assert( @@ -148,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("distributed test") { quietly { withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => - implicit val sqlContet = new SQLContext(sc) + implicit val sqlContext = new SQLContext(sc) val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val increment = (store: StateStore, iter: Iterator[String]) => { - iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) - } - store.commit() - store.iterator().map(rowsToStringInt) - } val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 1, keySchema, valueSchema) + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
@@ -183,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) + val key = stringToRow(s) + val oldValue = store.get(key).map(rowToInt).getOrElse(0) + store.put(key, intToRow(oldValue + 1)) } store.commit() store.iterator().map(rowsToStringInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 22b2f4f75d..0e5936d53f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth StateStore.stop() } - test("update, remove, commit, and all data iterator") { + test("get, put, remove, commit, and all data iterator") { val provider = newStoreProvider() // Verify state before starting a new set of updates @@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Verify state after updating - update(store, "a", 1) + put(store, "a", 1) intercept[IllegalStateException] { store.iterator() } @@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(provider.latestIterator().isEmpty) // Make updates, commit and then verify state - update(store, "b", 2) - update(store, "aa", 3) + put(store, "b", 2) + put(store, "aa", 3) remove(store, _.startsWith("a")) assert(store.commit() === 1) @@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val reloadedProvider = new HDFSBackedStateStoreProvider( store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) val reloadedStore = reloadedProvider.getStore(1) - update(reloadedStore, "c", 4) + put(reloadedStore, "c", 4) assert(reloadedStore.commit() === 2) assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) @@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("updates iterator with all combos of updates and removes") { val provider = newStoreProvider() var currentVersion: Int = 0 + def withStore(body: StateStore => Unit): Unit = { val store = provider.getStore(currentVersion) body(store) @@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // New data should be seen in updates as value added, even if they had multiple updates withStore { store => - update(store, "a", 1) - update(store, "aa", 1) - update(store, "aa", 2) + put(store, "a", 1) + put(store, "aa", 1) + put(store, "aa", 2) store.commit() assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) @@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Multiple updates to same key should be collapsed in the updates as a single value update // Keys that have not been updated should not appear in the updates withStore { store => - update(store, "a", 4) - update(store, "a", 6) + put(store, "a", 4) + 
put(store, "a", 6) store.commit() assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) @@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Keys added, updated and finally removed before commit should not appear in updates withStore { store => - update(store, "b", 4) // Added, finally removed - update(store, "bb", 5) // Added, updated, finally removed - update(store, "bb", 6) + put(store, "b", 4) // Added, finally removed + put(store, "bb", 5) // Added, updated, finally removed + put(store, "bb", 6) remove(store, _.startsWith("b")) store.commit() assert(updatesToSet(store.updates()) === Set.empty) @@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Removed, but re-added data should be seen in updates as a value update withStore { store => remove(store, _.startsWith("a")) - update(store, "a", 10) + put(store, "a", 10) store.commit() assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) assert(rowsToSet(store.iterator()) === Set("a" -> 10)) @@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("cancel") { val provider = newStoreProvider() val store = provider.getStore(0) - update(store, "a", 1) + put(store, "a", 1) store.commit() assert(rowsToSet(store.iterator()) === Set("a" -> 1)) // cancelUpdates should not change the data in the files val store1 = provider.getStore(1) - update(store1, "b", 1) - store1.cancel() + put(store1, "b", 1) + store1.abort() assert(getDataFromFiles(provider) === Set("a" -> 1)) } @@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Prepare some data in the stoer val store = provider.getStore(0) - update(store, "a", 1) + put(store, "a", 1) assert(store.commit() === 1) assert(rowsToSet(store.iterator()) === Set("a" -> 1)) @@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Update store version with some data val store1 = provider.getStore(1) - update(store1, "b", 1) + put(store1, "b", 1) assert(store1.commit() === 2) assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) // Overwrite the version with other data val store2 = provider.getStore(1) - update(store2, "c", 1) + put(store2, "c", 1) assert(store2.commit() === 2) assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) @@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth def updateVersionTo(targetVersion: Int): Unit = { for (i <- currentVersion + 1 to targetVersion) { val store = provider.getStore(currentVersion) - update(store, "a", i) + put(store, "a", i) store.commit() currentVersion += 1 } @@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth for (i <- 1 to 20) { val store = provider.getStore(i - 1) - update(store, "a", i) + put(store, "a", i) store.commit() provider.doMaintenance() // do cleanup } @@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val provider = newStoreProvider(minDeltasForSnapshot = 5) for (i <- 1 to 6) { val store = provider.getStore(i - 1) - update(store, "a", i) + put(store, "a", i) store.commit() provider.doMaintenance() // do cleanup } @@ 
-333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Increase version of the store val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) assert(store0.version === 0) - update(store0, "a", 1) + put(store0, "a", 1) store0.commit() assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) @@ -345,7 +346,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) - update(store1, "a", 2) + put(store1, "a", 2) assert(store1.commit() === 2) assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) } @@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth for (i <- 1 to 20) { val store = StateStore.get( storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) - update(store, "a", i) + put(store, "a", i) store.commit() } eventually(timeout(10 seconds)) { @@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth store.remove(row => condition(rowToString(row))) } - private def update(store: StateStore, key: String, value: Int): Unit = { - store.update(stringToRow(key), _ => intToRow(value)) + private def put(store: StateStore, key: String, value: Int): Unit = { + store.put(stringToRow(key), intToRow(value)) + } + + private def get(store: StateStore, key: String): Option[Int] = { + store.get(stringToRow(key)).map(rowToInt) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala new file mode 100644 index 0000000000..b63ce89d18 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkException +import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +object FailureSinglton { + var firstTime = true +} + +class StreamingAggregationSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + test("simple count") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated)( + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream, + StartStream, + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)) + ) + } + + test("multiple keys") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value", $"value" + 1) + .agg(count("*")) + .as[(Int, Int, Long)] + + testStream(aggregated)( + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 1), (2, 3, 1)), + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 2), (2, 3, 2)) + ) + } + + test("multiple aggregations") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*") as 'count) + .groupBy($"value" % 2) + .agg(sum($"count")) + .as[(Int, Long)] + + testStream(aggregated)( + AddData(inputData, 1, 2, 3, 4), + CheckLastBatch((0, 2), (1, 2)), + AddData(inputData, 1, 3, 5), + CheckLastBatch((1, 5)) + ) + } + + testQuietly("midbatch failure") { + val inputData = MemoryStream[Int] + FailureSinglton.firstTime = true + val aggregated = + inputData.toDS() + .map { i => + if (i == 4 && FailureSinglton.firstTime) { + FailureSinglton.firstTime = false + sys.error("injected failure") + } + + i + } + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated)( + StartStream, + AddData(inputData, 1, 2, 3, 4), + ExpectFailure[SparkException](), + StartStream, + CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1)) + ) + } + + test("typed aggregators") { + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + new SumOf(f).toColumn + + val inputData = MemoryStream[(String, Int)] + val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2)) + + testStream(aggregated)( + AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)), + CheckLastBatch(("a", 30), ("b", 3), ("c", 1)) + ) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 2bdb428e9d..ff40c366c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -77,8 +77,9 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) /** * Planner that takes into account Hive-specific strategies. 
*/ - override lazy val planner: SparkPlanner = { - new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies { + override def planner: SparkPlanner = { + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) + with HiveStrategies { override val hiveContext = ctx override def strategies: Seq[Strategy] = { -- cgit v1.2.3 From c16a396886672493df694f3ca30478c8edb771f0 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 1 Apr 2016 15:21:29 -0700 Subject: [SPARK-13825][CORE] Upgrade to Scala 2.11.8 ## What changes were proposed in this pull request? Upgrade to 2.11.8 (from the current 2.11.7) ## How was this patch tested? A manual build Author: Jacek Laskowski Closes #11681 from jaceklaskowski/SPARK-13825-scala-2_11_8. --- dev/deps/spark-deps-hadoop-2.2 | 8 ++++---- dev/deps/spark-deps-hadoop-2.3 | 8 ++++---- dev/deps/spark-deps-hadoop-2.4 | 8 ++++---- dev/deps/spark-deps-hadoop-2.6 | 8 ++++---- dev/deps/spark-deps-hadoop-2.7 | 8 ++++---- pom.xml | 6 +++--- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 115018e7c1..3865a9fb16 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -158,12 +158,12 @@ protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar -scala-compiler-2.11.7.jar -scala-library-2.11.7.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.7.jar +scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar -scalap-2.11.7.jar +scalap-2.11.8.jar servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 246d1147bf..4313799da7 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -149,12 +149,12 @@ protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar -scala-compiler-2.11.7.jar -scala-library-2.11.7.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.7.jar +scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar -scalap-2.11.7.jar +scalap-2.11.8.jar servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 0e2cdaf0d2..910ea685f2 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -150,12 +150,12 @@ protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar -scala-compiler-2.11.7.jar -scala-library-2.11.7.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.7.jar +scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar -scalap-2.11.7.jar +scalap-2.11.8.jar servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 1ed15595be..0692f24e47 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -156,12 +156,12 @@ protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar -scala-compiler-2.11.7.jar -scala-library-2.11.7.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.7.jar +scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar -scalap-2.11.7.jar +scalap-2.11.8.jar servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff 
--git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 218631ed6e..e397558e05 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -157,12 +157,12 @@ protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar reflectasm-1.07-shaded.jar -scala-compiler-2.11.7.jar -scala-library-2.11.7.jar +scala-compiler-2.11.8.jar +scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar -scala-reflect-2.11.7.jar +scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar -scalap-2.11.7.jar +scalap-2.11.8.jar servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar diff --git a/pom.xml b/pom.xml index be80e6b80c..e135c92c07 100644 --- a/pom.xml +++ b/pom.xml @@ -156,7 +156,7 @@ 3.4.1 3.2.2 - 2.11.7 + 2.11.8 2.11 ${scala.version} org.scala-lang @@ -2448,7 +2448,7 @@ scala-2.10 - 2.10.5 + 2.10.6 2.10 ${scala.version} org.scala-lang @@ -2480,7 +2480,7 @@ !scala-2.10 - 2.11.7 + 2.11.8 2.11 -- cgit v1.2.3 From 19f32f2d99c3620c0e562a98f7890316ddad1de9 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Fri, 1 Apr 2016 15:26:22 -0700 Subject: [SPARK-12857][STREAMING] Standardize "records" and "events" on "records" ## What changes were proposed in this pull request? Currently the Streaming tab in web UI uses records and events interchangeably; this PR tries to standardize them on "records". "records" is chosen over "events" because: - "records" is used extensively throughout the streaming documents, codes, and comments - "events" is used only in Streaming UI related codes and comments ## How was this patch tested? - existing test suites - manually checking on the Streaming UI tab Author: Liwei Lin Closes #12032 from lw-lin/streaming-events-to-records. --- .../spark/streaming/receiver/RateLimiter.scala | 2 +- .../spark/streaming/ui/AllBatchesTable.scala | 4 +- .../ui/StreamingJobProgressListener.scala | 10 ++-- .../apache/spark/streaming/ui/StreamingPage.scala | 65 +++++++++++----------- 4 files changed, 41 insertions(+), 40 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index b2189103a0..0a861f22b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -52,7 +52,7 @@ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { * Set the rate limit to `newRate`. The new rate will not exceed the maximum rate configured by * {{{spark.streaming.receiver.maxRate}}}, even if `newRate` is higher than that. * - * @param newRate A new rate in events per second. It has no effect if it's 0 or negative. + * @param newRate A new rate in records per second. It has no effect if it's 0 or negative. 
*/ private[receiver] def updateRate(newRate: Long): Unit = if (newRate > 0) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala index d339723427..c024b4ef7e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/AllBatchesTable.scala @@ -52,7 +52,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) protected def baseRow(batch: BatchUIData): Seq[Node] = { val batchTime = batch.batchTime.milliseconds val formattedBatchTime = UIUtils.formatBatchTime(batchTime, batchInterval) - val eventCount = batch.numRecords + val numRecords = batch.numRecords val schedulingDelay = batch.schedulingDelay val formattedSchedulingDelay = schedulingDelay.map(SparkUIUtils.formatDuration).getOrElse("-") val processingTime = batch.processingDelay @@ -65,7 +65,7 @@ private[ui] abstract class BatchTableBase(tableId: String, batchInterval: Long) {formattedBatchTime} - + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index d6fcc582b9..6985c37f71 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -202,21 +202,21 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) def streamIds: Seq[Int] = ssc.graph.getInputStreams().map(_.id) /** - * Return all of the event rates for each InputDStream in each batch. The key of the return value - * is the stream id, and the value is a sequence of batch time with its event rate. + * Return all of the record rates for each InputDStream in each batch. The key of the return value + * is the stream id, and the value is a sequence of batch time with its record rate. */ - def receivedEventRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { + def receivedRecordRateWithBatchTime: Map[Int, Seq[(Long, Double)]] = synchronized { val _retainedBatches = retainedBatches val latestBatches = _retainedBatches.map { batchUIData => (batchUIData.batchTime.milliseconds, batchUIData.streamIdToInputInfo.mapValues(_.numRecords)) } streamIds.map { streamId => - val eventRates = latestBatches.map { + val recordRates = latestBatches.map { case (batchTime, streamIdToNumRecords) => val numRecords = streamIdToNumRecords.getOrElse(streamId, 0L) (batchTime, numRecords * 1000.0 / batchDuration) } - (streamId, eventRates) + (streamId, recordRates) }.toMap } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala index fa40436221..b97e24f28b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -125,9 +125,9 @@ private[ui] class MillisecondsStatUIData(data: Seq[(Long, Long)]) { * A helper class for "input rate" to generate data that will be used in the timeline and histogram * graphs. * - * @param data (batchTime, event-rate). + * @param data (batch time, record rate). 
*/ -private[ui] class EventRateUIData(val data: Seq[(Long, Double)]) { +private[ui] class RecordRateUIData(val data: Seq[(Long, Double)]) { val avg: Option[Double] = if (data.isEmpty) None else Some(data.map(_._2).sum / data.size) @@ -215,7 +215,7 @@ private[ui] class StreamingPage(parent: StreamingTab) val minBatchTime = if (batchTimes.isEmpty) startTime else batchTimes.min val maxBatchTime = if (batchTimes.isEmpty) startTime else batchTimes.max - val eventRateForAllStreams = new EventRateUIData(batches.map { batchInfo => + val recordRateForAllStreams = new RecordRateUIData(batches.map { batchInfo => (batchInfo.batchTime.milliseconds, batchInfo.numRecords * 1000.0 / listener.batchDuration) }) @@ -241,24 +241,24 @@ private[ui] class StreamingPage(parent: StreamingTab) // Use the max input rate for all InputDStreams' graphs to make the Y axis ranges same. // If it's not an integral number, just use its ceil integral number. - val maxEventRate = eventRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L) - val minEventRate = 0L + val maxRecordRate = recordRateForAllStreams.max.map(_.ceil.toLong).getOrElse(0L) + val minRecordRate = 0L val batchInterval = UIUtils.convertToTimeUnit(listener.batchDuration, normalizedUnit) val jsCollector = new JsCollector - val graphUIDataForEventRateOfAllStreams = + val graphUIDataForRecordRateOfAllStreams = new GraphUIData( - "all-stream-events-timeline", - "all-stream-events-histogram", - eventRateForAllStreams.data, + "all-stream-records-timeline", + "all-stream-records-histogram", + recordRateForAllStreams.data, minBatchTime, maxBatchTime, - minEventRate, - maxEventRate, - "events/sec") - graphUIDataForEventRateOfAllStreams.generateDataJs(jsCollector) + minRecordRate, + maxRecordRate, + "records/sec") + graphUIDataForRecordRateOfAllStreams.generateDataJs(jsCollector) val graphUIDataForSchedulingDelay = new GraphUIData( @@ -334,16 +334,16 @@ private[ui] class StreamingPage(parent: StreamingTab)
Receivers: {listener.numActiveReceivers} / {numReceivers} active
} } -
Avg: {eventRateForAllStreams.formattedAvg} events/sec
+
Avg: {recordRateForAllStreams.formattedAvg} records/sec
- - + + {if (hasStream) { }} @@ -390,15 +390,16 @@ private[ui] class StreamingPage(parent: StreamingTab) maxX: Long, minY: Double, maxY: Double): Seq[Node] = { - val maxYCalculated = listener.receivedEventRateWithBatchTime.values - .flatMap { case streamAndRates => streamAndRates.map { case (_, eventRate) => eventRate } } + val maxYCalculated = listener.receivedRecordRateWithBatchTime.values + .flatMap { case streamAndRates => streamAndRates.map { case (_, recordRate) => recordRate } } .reduceOption[Double](math.max) .map(_.ceil.toLong) .getOrElse(0L) - val content = listener.receivedEventRateWithBatchTime.toList.sortBy(_._1).map { - case (streamId, eventRates) => - generateInputDStreamRow(jsCollector, streamId, eventRates, minX, maxX, minY, maxYCalculated) + val content = listener.receivedRecordRateWithBatchTime.toList.sortBy(_._1).map { + case (streamId, recordRates) => + generateInputDStreamRow( + jsCollector, streamId, recordRates, minX, maxX, minY, maxYCalculated) }.foldLeft[Seq[Node]](Nil)(_ ++ _) // scalastyle:off @@ -422,7 +423,7 @@ private[ui] class StreamingPage(parent: StreamingTab) private def generateInputDStreamRow( jsCollector: JsCollector, streamId: Int, - eventRates: Seq[(Long, Double)], + recordRates: Seq[(Long, Double)], minX: Long, maxX: Long, minY: Double, @@ -447,25 +448,25 @@ private[ui] class StreamingPage(parent: StreamingTab) val receiverLastErrorTime = receiverInfo.map { r => if (r.lastErrorTime < 0) "-" else SparkUIUtils.formatDate(r.lastErrorTime) }.getOrElse(emptyCell) - val receivedRecords = new EventRateUIData(eventRates) + val receivedRecords = new RecordRateUIData(recordRates) - val graphUIDataForEventRate = + val graphUIDataForRecordRate = new GraphUIData( - s"stream-$streamId-events-timeline", - s"stream-$streamId-events-histogram", + s"stream-$streamId-records-timeline", + s"stream-$streamId-records-histogram", receivedRecords.data, minX, maxX, minY, maxY, - "events/sec") - graphUIDataForEventRate.generateDataJs(jsCollector) + "records/sec") + graphUIDataForRecordRate.generateDataJs(jsCollector) @@ -475,9 +476,9 @@ private[ui] class StreamingPage(parent: StreamingTab) - + } -- cgit v1.2.3 From abc6c42c2d76d6aa4a8cac605cb4214e8f8f3328 Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Fri, 1 Apr 2016 16:18:09 -0700 Subject: [SPARK-13241][WEB UI] Added long values for dates in ApplicationAttemptInfo API ## What changes were proposed in this pull request? Adding long values for each Date in the ApplicationAttemptInfo API for easier use in code ## How was the this patch tested? Tested with dev/run-tests Author: Alex Bozarth Closes #11326 from ajbozarth/spark13241. 
--- .../scala/org/apache/spark/status/api/v1/api.scala | 6 +++++- .../application_list_json_expectation.json | 24 ++++++++++++++++++++++ .../completed_app_list_json_expectation.json | 24 ++++++++++++++++++++++ .../maxDate2_app_list_json_expectation.json | 3 +++ .../maxDate_app_list_json_expectation.json | 6 ++++++ .../minDate_app_list_json_expectation.json | 18 ++++++++++++++++ .../one_app_json_expectation.json | 3 +++ .../one_app_multi_attempt_json_expectation.json | 6 ++++++ .../spark/deploy/history/HistoryServerSuite.scala | 4 +++- 9 files changed, 92 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index d43868bbcb..ebbbf48148 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -38,7 +38,11 @@ class ApplicationAttemptInfo private[spark]( val lastUpdated: Date, val duration: Long, val sparkUser: String, - val completed: Boolean = false) + val completed: Boolean = false) { + def getStartTimeEpoch: Long = startTime.getTime + def getEndTimeEpoch: Long = endTime.getTime + def getLastUpdatedEpoch: Long = lastUpdated.getTime +} class ExecutorStageSummary private[spark]( val taskTime : Long, diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index 5bbb4ceb97..1a13233133 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -2,6 +2,9 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", "lastUpdated" : "", @@ -14,6 +17,9 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", "lastUpdated" : "", @@ -22,6 +28,9 @@ "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", "lastUpdated" : "", @@ -34,6 +43,9 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", "lastUpdated" : "", @@ -42,6 +54,9 @@ "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", "lastUpdated" : "", @@ -53,6 +68,9 @@ "id" : "local-1425081759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", "lastUpdated" : "", @@ -64,6 +82,9 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 
1422981788731, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", "lastUpdated" : "", @@ -75,6 +96,9 @@ "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", "lastUpdated" : "", diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index 5bbb4ceb97..1a13233133 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -2,6 +2,9 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", "lastUpdated" : "", @@ -14,6 +17,9 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", "lastUpdated" : "", @@ -22,6 +28,9 @@ "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", "lastUpdated" : "", @@ -34,6 +43,9 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", "lastUpdated" : "", @@ -42,6 +54,9 @@ "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", "lastUpdated" : "", @@ -53,6 +68,9 @@ "id" : "local-1425081759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", "lastUpdated" : "", @@ -64,6 +82,9 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", "lastUpdated" : "", @@ -75,6 +96,9 @@ "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", "lastUpdated" : "", diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json index 3f80a529a0..eacf04b901 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json @@ -2,6 +2,9 @@ 
"id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", "lastUpdated" : "", diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json index 508bdc17ef..adad25bf17 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json @@ -2,6 +2,9 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", "lastUpdated" : "", @@ -13,6 +16,9 @@ "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", "lastUpdated" : "", diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index 5dca7d73de..a658909088 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -2,6 +2,9 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", "lastUpdated" : "", @@ -14,6 +17,9 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", "lastUpdated" : "", @@ -22,6 +28,9 @@ "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", "lastUpdated" : "", @@ -34,6 +43,9 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", "lastUpdated" : "", @@ -42,6 +54,9 @@ "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", "lastUpdated" : "", @@ -54,6 +69,9 @@ "name": "Spark shell", "attempts": [ { + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0, "startTime": "2015-02-28T00:02:38.277GMT", "endTime": "2015-02-28T00:02:46.912GMT", "lastUpdated" : "", diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json index cca32c7910..0217facad9 100644 --- 
a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json @@ -2,6 +2,9 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", "lastUpdated" : "", diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json index 1ea1779e83..b20a26648e 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json @@ -3,6 +3,9 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", "lastUpdated" : "", @@ -11,6 +14,9 @@ "completed" : true }, { "attemptId" : "1", + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", "lastUpdated" : "", diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 79e4efb1a8..2a013aca7b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -162,7 +162,9 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val json = if (jsonOrg.indexOf("lastUpdated") >= 0) { val subStrings = jsonOrg.split(",") for (i <- subStrings.indices) { - if (subStrings(i).indexOf("lastUpdated") >= 0) { + if (subStrings(i).indexOf("lastUpdatedEpoch") >= 0) { + subStrings(i) = subStrings(i).replaceAll("(\\d+)", "0") + } else if (subStrings(i).indexOf("lastUpdated") >= 0) { subStrings(i) = "\"lastUpdated\":\"\"" } } -- cgit v1.2.3 From 36e8fb8005eccea67a9dea8cf68ec3105aa43351 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Fri, 1 Apr 2016 18:25:43 -0700 Subject: [SPARK-7425][ML] spark.ml Predictor should support other numeric types for label Currently, the Predictor abstraction expects the input labelCol type to be DoubleType, but we should support other numeric types. This will involve updating the PredictorParams.validateAndTransformSchema method. Author: BenFradet Closes #10355 from BenFradet/SPARK-7425. 
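The change follows one pattern throughout: validateAndTransformSchema now only requires the label column to be some NumericType, and each training path casts that column to DoubleType before building (label, features) records. The sketch below is illustrative only, not code from this patch; the helper names and standalone functions are assumptions, and it presumes an existing DataFrame with a numeric label column:

    import org.apache.spark.sql.{DataFrame, Row}
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}

    // Accept any numeric label type up front (mirrors the new checkNumericType-style check).
    def requireNumericLabel(schema: StructType, labelCol: String): Unit = {
      val actualDataType = schema(labelCol).dataType
      require(actualDataType.isInstanceOf[NumericType],
        s"Column $labelCol must be of type NumericType but was actually of type $actualDataType.")
    }

    // Cast whatever numeric label type the caller provides (Int, Long, Float, Decimal, ...)
    // to Double once, at extraction time, before handing labels to the learning algorithm.
    def extractLabels(df: DataFrame, labelCol: String): Array[Double] = {
      requireNumericLabel(df.schema, labelCol)
      df.select(col(labelCol).cast(DoubleType)).rdd
        .map { case Row(label: Double) => label }
        .collect()
    }

Casting once at extraction time keeps the algorithms themselves untouched, since they continue to see Double labels.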
--- .../main/scala/org/apache/spark/ml/Predictor.scala | 9 +-- .../ml/classification/LogisticRegression.scala | 7 +- .../apache/spark/ml/classification/OneVsRest.scala | 4 +- .../ml/regression/AFTSurvivalRegression.scala | 11 +-- .../regression/GeneralizedLinearRegression.scala | 13 ++-- .../spark/ml/regression/IsotonicRegression.scala | 4 +- .../spark/ml/regression/LinearRegression.scala | 11 +-- .../org/apache/spark/ml/util/SchemaUtils.scala | 24 ++++-- .../DecisionTreeClassifierSuite.scala | 15 +++- .../ml/classification/GBTClassifierSuite.scala | 9 ++- .../classification/LogisticRegressionSuite.scala | 11 ++- .../MultilayerPerceptronClassifierSuite.scala | 12 +++ .../spark/ml/classification/NaiveBayesSuite.scala | 14 +++- .../spark/ml/classification/OneVsRestSuite.scala | 16 +++- .../RandomForestClassifierSuite.scala | 8 ++ .../ml/regression/AFTSurvivalRegressionSuite.scala | 9 +++ .../ml/regression/DecisionTreeRegressorSuite.scala | 8 ++ .../spark/ml/regression/GBTRegressorSuite.scala | 8 +- .../GeneralizedLinearRegressionSuite.scala | 12 ++- .../ml/regression/IsotonicRegressionSuite.scala | 9 +++ .../ml/regression/LinearRegressionSuite.scala | 17 ++++- .../ml/regression/RandomForestRegressorSuite.scala | 8 ++ .../org/apache/spark/ml/tree/impl/TreeTests.scala | 18 +++++ .../org/apache/spark/ml/util/MLTestingUtils.scala | 86 +++++++++++++++++++++- 24 files changed, 294 insertions(+), 49 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index ebe48700f8..d23ae6f794 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -36,6 +36,7 @@ private[ml] trait PredictorParams extends Params /** * Validates and transforms the input schema with the provided param map. + * * @param schema input schema * @param fitting whether this is in fitting * @param featuresDataType SQL DataType for FeaturesType. @@ -49,8 +50,7 @@ private[ml] trait PredictorParams extends Params // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) if (fitting) { - // TODO: Allow other numeric types - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) } SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } @@ -121,9 +121,8 @@ abstract class Predictor[ * and put it in an RDD with strong types. 
*/ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { - dataset.select($(labelCol), $(featuresCol)).rdd.map { - case Row(label: Double, features: Vector) => - LabeledPoint(label, features) + dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { + case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 3d1d5b6892..aeb94a6600 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -38,6 +38,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel /** @@ -265,7 +266,7 @@ class LogisticRegression @Since("1.2.0") ( LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { + dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -361,7 +362,7 @@ class LogisticRegression @Since("1.2.0") ( if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) { val vec = optInitialModel.get.coefficients logWarning( - s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}") + s"Initial coefficients provided $vec did not match the expected size $numFeatures") } if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) { @@ -522,7 +523,7 @@ class LogisticRegressionModel private[spark] ( (LogisticRegressionModel, String) = { $(probabilityCol) match { case "" => - val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString() + val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName) case p => (this, p) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 98b99a3485..263d54ce4d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -295,10 +295,12 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.4.0") override def fit(dataset: DataFrame): OneVsRestModel = { + transformSchema(dataset.schema) + // determine number of classes either from metadata if provided, or via computation. 
val labelSchema = dataset.schema($(labelCol)) val computeNumClasses: () => Int = () => { - val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head() + val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head() // classes are assumed to be numbered from 0,...,maxLabelIndex maxLabelIndex.toInt + 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index ba5708ab8d..3278974954 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -103,7 +103,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) } if (hasQuantilesCol) { SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) @@ -184,10 +184,11 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * and put it in an RDD with strong types. */ protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { - dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map { - case Row(features: Vector, label: Double, censor: Double) => - AFTPoint(features, label, censor) - } + dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) + .rdd.map { + case Row(features: Vector, label: Double, censor: Double) => + AFTPoint(features, label, censor) + } } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 0e71e8d8e1..a40d3731cb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** * Params for Generalized Linear Regression. @@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam * to be used in the model. * Supported options: "gaussian", "binomial", "poisson" and "gamma". * Default is "gaussian". + * * @group param */ @Since("2.0.0") @@ -63,6 +64,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam * Param for the name of link function which provides the relationship * between the linear predictor and the mean of the distribution function. * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt". 
+ * * @group param */ @Since("2.0.0") @@ -210,9 +212,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val } val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd - .map { case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) + val instances: RDD[Instance] = + dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) } if (familyObj == Gaussian && linkObj == Identity) { @@ -698,7 +701,7 @@ class GeneralizedLinearRegressionModel private[ml] ( : (GeneralizedLinearRegressionModel, String) = { $(predictionCol) match { case "" => - val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) case p => (this, p) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index fb733f9a34..bd0b631d89 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures } else { lit(1.0) } - dataset.select(col($(labelCol)), f, w).rdd.map { + dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map { case Row(label: Double, feature: Double, weight: Double) => (label, feature, weight) } @@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures schema: StructType, fitting: Boolean): StructType = { if (fitting) { - SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(labelCol)) if (hasWeightCol) { SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType) } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 5ec02135cc..ba5ad4c072 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -40,6 +40,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel /** @@ -171,7 +172,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // For low dimensional data, WeightedLeastSquares is more efficiently since the // training algorithm only requires one pass through the data. 
(SPARK-10668) val instances: RDD[Instance] = dataset.select( - col($(labelCol)), w, col($(featuresCol))).rdd.map { + col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map { case Row(label: Double, weight: Double, features: Vector) => Instance(label, weight, features) } @@ -431,7 +432,7 @@ class LinearRegressionModel private[ml] ( private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = { $(predictionCol) match { case "" => - val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString() + val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName) case p => (this, p) } @@ -550,7 +551,7 @@ class LinearRegressionSummary private[regression] ( @transient private val metrics = new RegressionMetrics( predictions - .select(predictionCol, labelCol) + .select(col(predictionCol), col(labelCol).cast(DoubleType)) .rdd .map { case Row(pred: Double, label: Double) => (pred, label) }, !model.getFitIntercept) @@ -653,7 +654,7 @@ class LinearRegressionSummary private[regression] ( col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0) } val sigma2 = rss / degreesOfFreedom - diagInvAtWA.map(_ * sigma2).map(math.sqrt(_)) + diagInvAtWA.map(_ * sigma2).map(math.sqrt) } } @@ -826,7 +827,7 @@ private class LeastSquaresAggregator( instance match { case Instance(label, weight, features) => require(dim == features.size, s"Dimensions mismatch when adding new sample." + s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0") + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 76021ad8f4..334410c962 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.util -import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType} /** @@ -44,10 +44,10 @@ private[spark] object SchemaUtils { } /** - * Check whether the given schema contains a column of one of the require data types. - * @param colName column name - * @param dataTypes required column data types - */ + * Check whether the given schema contains a column of one of the require data types. + * @param colName column name + * @param dataTypes required column data types + */ def checkColumnTypes( schema: StructType, colName: String, @@ -60,6 +60,20 @@ private[spark] object SchemaUtils { s"${dataTypes.mkString("[", ", ", "]")} but was actually of type $actualDataType.$message") } + /** + * Check whether the given schema contains a column of the numeric data type. + * @param colName column name + */ + def checkNumericType( + schema: StructType, + colName: String, + msg: String = ""): Unit = { + val actualDataType = schema(colName).dataType + val message = if (msg != null && msg.trim.length > 0) " " + msg else "" + require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " + + s"NumericType but was actually of type $actualDataType.$message") + } + /** * Appends a new column to the input schema. This fails if the given output column already exists. 
* @param schema input schema diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 2b07524815..fe839e15e9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -27,8 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -176,7 +175,7 @@ class DecisionTreeClassifierSuite } test("Multiclass classification tree with 10-ary (ordered) categorical features," + - " with just enough bins") { + " with just enough bins") { val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD val dt = new DecisionTreeClassifier() .setImpurity("Gini") @@ -273,7 +272,7 @@ class DecisionTreeClassifierSuite )) val df = TreeTests.setMetadata(data, Map(0 -> 1), 2) val dt = new DecisionTreeClassifier().setMaxDepth(3) - val model = dt.fit(df) + dt.fit(df) } test("Use soft prediction for binary classification with ordered categorical features") { @@ -335,6 +334,14 @@ class DecisionTreeClassifierSuite assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val dt = new DecisionTreeClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier]( + dt, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index bf7481e8a3..76d8c9372e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -31,7 +31,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.util.Utils - /** * Test suite for [[GBTClassifier]]. 
*/ @@ -102,6 +101,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { Utils.deleteRecursively(tempDir) } + test("should support all NumericType labels and not support other types") { + val gbt = new GBTClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier]( + gbt, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index afeeaf7fb5..7eefaf2346 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -103,7 +103,7 @@ class LogisticRegressionSuite assert(model.hasSummary) // Validate that we re-insert a probability column for evaluation val fieldNames = model.summary.predictions.schema.fieldNames - assert((dataset.schema.fieldNames.toSet).subsetOf( + assert(dataset.schema.fieldNames.toSet.subsetOf( fieldNames.toSet)) assert(fieldNames.exists(s => s.startsWith("probability_"))) } @@ -934,6 +934,15 @@ class LogisticRegressionSuite testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val lr = new LogisticRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( + lr, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients.toArray === actual.coefficients.toArray) + } + } } object LogisticRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 43781385db..06ff049b48 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics @@ -162,4 +163,15 @@ class MultilayerPerceptronClassifierSuite assert(newMlpModel.layers === mlpModel.layers) assert(newMlpModel.weights === mlpModel.weights) } + + test("should support all NumericType labels and not support other types") { + val layers = Array(3, 2) + val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) + MLTestingUtils.checkNumericTypes[ + MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( + mpc, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.layers === actual.layers) + assert(expected.weights === actual.weights) + } + } } diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 082a6bcd21..4727cd436f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,7 +21,7 @@ import breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ @@ -86,7 +86,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa model: NaiveBayesModel, modelType: String): Unit = { featureAndProbabilities.collect().foreach { - case Row(features: Vector, probability: Vector) => { + case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { case Multinomial => @@ -97,7 +97,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa throw new UnknownError(s"Invalid modelType: $modelType.") } assert(probability ~== expected relTol 1.0e-10) - } } } @@ -185,6 +184,15 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val nb = new NaiveBayes() testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val nb = new NaiveBayes() + MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( + nb, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.pi === actual.pi) + assert(expected.theta === actual.theta) + } + } } object NaiveBayesSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 51c1baf682..4131396726 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -74,7 +74,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau // copied model must have the same parent. 
MLTestingUtils.checkCopy(ovaModel) - assert(ovaModel.models.size === numClasses) + assert(ovaModel.models.length === numClasses) val transformedDataset = ovaModel.transform(dataset) // check for label metadata in prediction col @@ -224,6 +224,20 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false) checkModelData(ovaModel, newOvaModel) } + + test("should support all NumericType labels and not support other types") { + val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1)) + MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest]( + ovr, isClassification = true, sqlContext) { (expected, actual) => + val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel]) + val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel]) + assert(expectedModels.length === actualModels.length) + expectedModels.zip(actualModels).foreach { case (e, a) => + assert(e.intercept === a.intercept) + assert(e.coefficients.toArray === a.coefficients.toArray) + } + } + } } private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index b896099e31..052bc83c38 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -178,6 +178,14 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val rf = new RandomForestClassifier().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier]( + rf, isClassification = true, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index dbd752d2aa..f4844cc671 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -347,6 +347,15 @@ class AFTSurvivalRegressionSuite } } + test("should support all NumericType labels") { + val aft = new AFTSurvivalRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( + aft, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } + test("read/write") { def checkModelData( model: AFTSurvivalRegressionModel, diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 662e3fc679..e9fb2677b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -117,6 +117,14 @@ class DecisionTreeRegressorSuite assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val dt = new DecisionTreeRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor]( + dt, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dfb8418086..914818f41f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.util.Utils - /** * Test suite for [[GBTRegressor]]. */ @@ -110,7 +109,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.checkpointDir = None Utils.deleteRecursively(tempDir) + } + test("should support all NumericType labels and not support other types") { + val gbt = new GBTRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor]( + gbt, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } } // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 4ebdbf2213..2265464b51 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -982,6 +982,16 @@ class GeneralizedLinearRegressionSuite testEstimatorAndModelReadWrite(glr, datasetPoissonLog, GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val glr = new GeneralizedLinearRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[ + GeneralizedLinearRegressionModel, GeneralizedLinearRegression]( + glr, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } } object GeneralizedLinearRegressionSuite { @@ -1023,7 +1033,7 @@ object GeneralizedLinearRegressionSuite { generator.setSeed(seed) (0 until nPoints).map { _ => - val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray) + val features = Vectors.dense(coefficients.indices.map(rndElement).toArray) val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept val mu = link match { case "identity" => eta diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index b8874b4cd3..3a10ad7ed0 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -180,6 +180,15 @@ class IsotonicRegressionSuite testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val ir = new IsotonicRegression() + MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression]( + ir, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.boundaries === actual.boundaries) + assert(expected.predictions === actual.predictions) + } + } } object IsotonicRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index bd45d21e8d..cccb7f8d1b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -61,9 +61,9 @@ class LinearRegressionSuite val featureSize = 4100 datasetWithSparseFeature = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( - intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray, - xMean = Seq.fill(featureSize)(r.nextDouble).toArray, - xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200, + intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray, + xMean = Seq.fill(featureSize)(r.nextDouble()).toArray, + xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200, seed, eps = 0.1, sparsity = 0.7), 2)) /* @@ -687,7 +687,7 @@ class LinearRegressionSuite // Validate that we re-insert a prediction column for evaluation val modelNoPredictionColFieldNames = modelNoPredictionCol.summary.predictions.schema.fieldNames - assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf( + assert(datasetWithDenseFeature.schema.fieldNames.toSet.subsetOf( modelNoPredictionColFieldNames.toSet)) assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_"))) @@ -1006,6 +1006,15 @@ class LinearRegressionSuite testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val lr = new LinearRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression]( + lr, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } } object LinearRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 6be0c8bca0..2ab4f1b146 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -94,6 +94,14 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex assert(importances.toArray.forall(_ >= 0.0)) } + test("should support all NumericType labels and not support other types") { + val rf = new RandomForestRegressor().setMaxDepth(1) + MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, 
RandomForestRegressor]( + rf, isClassification = false, sqlContext) { (expected, actual) => + TreeTests.checkEqual(expected, actual) + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index 12808b0305..bd5bd17147 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -73,6 +73,24 @@ private[ml] object TreeTests extends SparkFunSuite { numClasses) } + /** + * Set label metadata (particularly the number of classes) on a DataFrame. + * @param data Dataset. Categorical features and labels must already have 0-based indices. + * This must be non-empty. + * @param numClasses Number of classes label can take. If 0, mark as continuous. + * @param labelColName Name of the label column on which to set the metadata. + * @return DataFrame with metadata + */ + def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): DataFrame = { + val labelAttribute = if (numClasses == 0) { + NumericAttribute.defaultAttr.withName(labelColName) + } else { + NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses) + } + val labelMetadata = labelAttribute.toMetadata() + data.select(data("features"), data(labelColName).as(labelColName, labelMetadata)) + } + /** * Check if the two trees are exactly the same. * Note: I hesitate to override Node.equals since it could cause problems if users diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index d290cc9b06..8108460518 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -17,14 +17,96 @@ package org.apache.spark.ml.util -import org.apache.spark.ml.Model +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ -object MLTestingUtils { +object MLTestingUtils extends SparkFunSuite { def checkCopy(model: Model[_]): Unit = { val copied = model.copy(ParamMap.empty) .asInstanceOf[Model[_]] assert(copied.parent.uid == model.parent.uid) assert(copied.parent == model.parent) } + + def checkNumericTypes[M <: Model[M], T <: Estimator[M]]( + estimator: T, + isClassification: Boolean, + sqlContext: SQLContext)(check: (M, M) => Unit): Unit = { + val dfs = if (isClassification) { + genClassifDFWithNumericLabelCol(sqlContext) + } else { + genRegressionDFWithNumericLabelCol(sqlContext) + } + val expected = estimator.fit(dfs(DoubleType)) + val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t))) + actuals.foreach(actual => check(expected, actual)) + + val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext) + val thrown = intercept[IllegalArgumentException] { + estimator.fit(dfWithStringLabels) + } + assert(thrown.getMessage contains + "Column label must be of type NumericType but was actually of type StringType") + } + + def 
genClassifDFWithNumericLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features"): Map[NumericType, DataFrame] = { + val df = sqlContext.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3)), + (1, Vectors.dense(0, 3, 1)), + (0, Vectors.dense(0, 2, 2)), + (1, Vectors.dense(0, 3, 9)), + (0, Vectors.dense(0, 2, 6)) + )).toDF(labelColName, featuresColName) + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types.map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName))) + .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) } + .toMap + } + + def genRegressionDFWithNumericLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features", + censorColName: String = "censor"): Map[NumericType, DataFrame] = { + val df = sqlContext.createDataFrame(Seq( + (0, Vectors.dense(0)), + (1, Vectors.dense(1)), + (2, Vectors.dense(2)), + (3, Vectors.dense(3)), + (4, Vectors.dense(4)) + )).toDF(labelColName, featuresColName) + + val types = + Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0)) + types + .map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName))) + .map { case (t, d) => + t -> TreeTests.setMetadata(d, 0, labelColName).withColumn(censorColName, lit(0.0)) + } + .toMap + } + + def generateDFWithStringLabelCol( + sqlContext: SQLContext, + labelColName: String = "label", + featuresColName: String = "features", + censorColName: String = "censor"): DataFrame = + sqlContext.createDataFrame(Seq( + ("0", Vectors.dense(0, 2, 3), 0.0), + ("1", Vectors.dense(0, 3, 1), 1.0), + ("0", Vectors.dense(0, 2, 2), 0.0), + ("1", Vectors.dense(0, 3, 9), 1.0), + ("0", Vectors.dense(0, 2, 6), 0.0) + )).toDF(labelColName, featuresColName, censorColName) } -- cgit v1.2.3 From 4fc35e6f5c590feb47cbcb5b1136f2e985677b3f Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 1 Apr 2016 21:23:35 -0700 Subject: [SPARK-14308][ML][MLLIB] Remove unused mllib tree classes and move private classes to ML ## What changes were proposed in this pull request? Decision tree helper classes will be migrated to ML. This patch moves those internal classes that are not part of the public API and removes ones that are no longer used, after [SPARK-12183](https://github.com/apache/spark/pull/11855). No functional changes are made. Details: * Bin.scala is removed as the ML implementation does not require bins * mllib NodeIdCache is removed. It was only used by the mllib implementation previously, which no longer exists * mllib TreePoint is removed. It was only used by the mllib implementation previously, which no longer exists * BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, BaggedPointSuite and TimeTracker are all moved to ML. ## How was this patch tested? No functional changes are made. Existing unit tests ensure behavior is unchanged. Author: sethah Closes #12097 from sethah/cleanup_mllib_tree. 
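For context on what the relocated helpers do, the sketch below mirrors the core of BaggedPoint's subsample-weight logic (the full class appears in the diff that follows); it is a simplified, illustrative version under assumed parameter names, not the moved code itself:

    import scala.util.Random
    import org.apache.commons.math3.distribution.PoissonDistribution

    // Per-subsample weights for a single datum: Poisson counts when sampling with
    // replacement, 0/1 keep-or-drop weights when sampling without replacement.
    def subsampleWeights(
        numSubsamples: Int,
        subsamplingRate: Double,
        withReplacement: Boolean,
        seed: Long): Array[Double] = {
      if (withReplacement) {
        val poisson = new PoissonDistribution(subsamplingRate)
        poisson.reseedRandomGenerator(seed)
        Array.fill(numSubsamples)(poisson.sample().toDouble)
      } else {
        val rng = new Random(seed)
        Array.fill(numSubsamples)(if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0)
      }
    }

For example, weights of Array(1.0, 0.0, 4.0) mean the datum appears once, zero times, and four times in three subsamples, which is exactly the representation the relocated BaggedPoint class stores alongside each instance.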
--- .../apache/spark/ml/tree/impl/BaggedPoint.scala | 125 ++++++++++++ .../spark/ml/tree/impl/DTStatsAggregator.scala | 181 +++++++++++++++++ .../spark/ml/tree/impl/DecisionTreeMetadata.scala | 217 +++++++++++++++++++++ .../spark/ml/tree/impl/GradientBoostedTrees.scala | 1 - .../apache/spark/ml/tree/impl/NodeIdCache.scala | 1 - .../apache/spark/ml/tree/impl/RandomForest.scala | 4 +- .../apache/spark/ml/tree/impl/TimeTracker.scala | 70 +++++++ .../org/apache/spark/ml/tree/impl/TreePoint.scala | 1 - .../spark/mllib/tree/GradientBoostedTrees.scala | 3 +- .../apache/spark/mllib/tree/impl/BaggedPoint.scala | 125 ------------ .../spark/mllib/tree/impl/DTStatsAggregator.scala | 178 ----------------- .../mllib/tree/impl/DecisionTreeMetadata.scala | 217 --------------------- .../apache/spark/mllib/tree/impl/NodeIdCache.scala | 195 ------------------ .../apache/spark/mllib/tree/impl/TimeTracker.scala | 70 ------- .../apache/spark/mllib/tree/impl/TreePoint.scala | 150 -------------- .../apache/spark/mllib/tree/impurity/Entropy.scala | 2 +- .../apache/spark/mllib/tree/impurity/Gini.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 2 +- .../org/apache/spark/mllib/tree/model/Bin.scala | 47 ----- .../spark/ml/tree/impl/BaggedPointSuite.scala | 99 ++++++++++ .../spark/ml/tree/impl/RandomForestSuite.scala | 1 - .../spark/mllib/tree/DecisionTreeSuite.scala | 2 +- .../spark/mllib/tree/impl/BaggedPointSuite.scala | 99 ---------- 23 files changed, 699 insertions(+), 1093 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala new file mode 100644 index 0000000000..4e372702f0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.commons.math3.distribution.PoissonDistribution + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom + +/** + * Internal representation of a datapoint which belongs to several subsamples of the same dataset, + * particularly for bagging (e.g., for random forests). + * + * This holds one instance, as well as an array of weights which represent the (weighted) + * number of times which this instance appears in each subsamplingRate. + * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that + * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. + * + * @param datum Data instance + * @param subsampleWeights Weight of this instance in each subsampled dataset. + * + * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted + * dataset support, update. (We store subsampleWeights as Double for this future extension.) + */ +private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) + extends Serializable + +private[spark] object BaggedPoint { + + /** + * Convert an input dataset into its BaggedPoint representation, + * choosing subsamplingRate counts for each instance. + * Each subsamplingRate has the same number of instances as the original dataset, + * and is created by subsampling without replacement. + * @param input Input dataset. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param numSubsamples Number of subsamples of this RDD to take. + * @param withReplacement Sampling with/without replacement. + * @param seed Random seed. + * @return BaggedPoint dataset representation. + */ + def convertToBaggedRDD[Datum] ( + input: RDD[Datum], + subsamplingRate: Double, + numSubsamples: Int, + withReplacement: Boolean, + seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { + if (withReplacement) { + convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) + } else { + if (numSubsamples == 1 && subsamplingRate == 1.0) { + convertToBaggedRDDWithoutSampling(input) + } else { + convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) + } + } + } + + private def convertToBaggedRDDSamplingWithoutReplacement[Datum] ( + input: RDD[Datum], + subsamplingRate: Double, + numSubsamples: Int, + seed: Long): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. 
+ val rng = new XORShiftRandom + rng.setSeed(seed + partitionIndex + 1) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + val x = rng.nextDouble() + subsampleWeights(subsampleIndex) = { + if (x < subsamplingRate) 1.0 else 0.0 + } + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + private def convertToBaggedRDDSamplingWithReplacement[Datum] ( + input: RDD[Datum], + subsample: Double, + numSubsamples: Int, + seed: Long): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val poisson = new PoissonDistribution(subsample) + poisson.reseedRandomGenerator(seed + partitionIndex + 1) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + subsampleWeights(subsampleIndex) = poisson.sample() + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + private def convertToBaggedRDDWithoutSampling[Datum] ( + input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { + input.map(datum => new BaggedPoint(datum, Array(1.0))) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala new file mode 100644 index 0000000000..61091bb803 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.mllib.tree.impurity._ + + + +/** + * DecisionTree statistics aggregator for a node. + * This holds a flat array of statistics for a set of (features, bins) + * and helps with indexing. + * This class is abstract to support learning with and without feature subsampling. + */ +private[spark] class DTStatsAggregator( + val metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]) extends Serializable { + + /** + * [[ImpurityAggregator]] instance specifying the impurity type. + */ + val impurityAggregator: ImpurityAggregator = metadata.impurity match { + case Gini => new GiniAggregator(metadata.numClasses) + case Entropy => new EntropyAggregator(metadata.numClasses) + case Variance => new VarianceAggregator() + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + + /** + * Number of elements (Double values) used for the sufficient statistics of each bin. 
+ */ + private val statsSize: Int = impurityAggregator.statsSize + + /** + * Number of bins for each feature. This is indexed by the feature index. + */ + private val numBins: Array[Int] = { + if (featureSubset.isDefined) { + featureSubset.get.map(metadata.numBins(_)) + } else { + metadata.numBins + } + } + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + */ + private val featureOffsets: Array[Int] = { + numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + + /** + * Total number of elements stored in this aggregator + */ + private val allStatsSize: Int = featureOffsets.last + + /** + * Flat array of elements. + * Index for start of stats for a (feature, bin) is: + * index = featureOffsets(featureIndex) + binIndex * statsSize + */ + private val allStats: Array[Double] = new Array[Double](allStatsSize) + + /** + * Array of parent node sufficient stats. + * + * Note: this is necessary because stats for the parent node are not available + * on the first iteration of tree learning. + */ + private val parentStats: Array[Double] = new Array[Double](statsSize) + + /** + * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + * + * @param featureOffset This is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. + */ + def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { + impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) + } + + /** + * Get an [[ImpurityCalculator]] for the parent node. + */ + def getParentImpurityCalculator(): ImpurityCalculator = { + impurityAggregator.getCalculator(parentStats, 0) + } + + /** + * Update the stats for a given (feature, bin) for ordered features, using the given label. + */ + def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { + val i = featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + + /** + * Update the parent node stats using the given label. + */ + def updateParent(label: Double, instanceWeight: Double): Unit = { + impurityAggregator.update(parentStats, 0, label, instanceWeight) + } + + /** + * Faster version of [[update]]. + * Update the stats for a given (feature, bin), using the given label. + * + * @param featureOffset This is a pre-computed feature offset + * from [[getFeatureOffset]]. + */ + def featureUpdate( + featureOffset: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, + label, instanceWeight) + } + + /** + * Pre-compute feature offset for use with [[featureUpdate]]. + * For ordered features only. + */ + def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) + + /** + * For a given feature, merge the stats for two bins. + * + * @param featureOffset This is a pre-computed feature offset + * from [[getFeatureOffset]]. + * @param binIndex The other bin is merged into this bin. + * @param otherBinIndex This bin is not modified. + */ + def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { + impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize, + featureOffset + otherBinIndex * statsSize) + } + + /** + * Merge this aggregator with another, and returns this aggregator. + * This method modifies this aggregator in-place. 
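A small worked sketch of the flat indexing scheme described above; the statsSize and per-feature bin counts below are made-up illustration values, not taken from the patch:

    // Three features with 3, 2 and 4 bins, and 2 stats values per bin.
    val statsSize = 2
    val numBins = Array(3, 2, 4)
    val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
    // featureOffsets == Array(0, 6, 10, 18), so allStats would hold featureOffsets.last == 18 doubles.
    // The stats for (featureIndex = 2, binIndex = 1) start at:
    val start = featureOffsets(2) + 1 * statsSize   // 10 + 2 == 12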
+ */ + def merge(other: DTStatsAggregator): DTStatsAggregator = { + require(allStatsSize == other.allStatsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." + + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") + var i = 0 + // TODO: Test BLAS.axpy + while (i < allStatsSize) { + allStats(i) += other.allStats(i) + i += 1 + } + + require(statsSize == other.statsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length parent " + + s"stats vectors. This aggregator's parent stats are length $statsSize, " + + s"but the other is ${other.statsSize}.") + var j = 0 + while (j < statsSize) { + parentStats(j) += other.parentStats(j) + j += 1 + } + + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala new file mode 100644 index 0000000000..df8eb5d1f9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.rdd.RDD + +/** + * Learning and dataset metadata for DecisionTree. + * + * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. + * For regression: fixed at 0 (no meaning). + * @param maxBins Maximum number of bins, for all features. + * @param featureArity Map: categorical feature index --> arity. + * I.e., the feature takes values in {0, ..., arity - 1}. + * @param numBins Number of bins for each feature. 
+ */ +private[spark] class DecisionTreeMetadata( + val numFeatures: Int, + val numExamples: Long, + val numClasses: Int, + val maxBins: Int, + val featureArity: Map[Int, Int], + val unorderedFeatures: Set[Int], + val numBins: Array[Int], + val impurity: Impurity, + val quantileStrategy: QuantileStrategy, + val maxDepth: Int, + val minInstancesPerNode: Int, + val minInfoGain: Double, + val numTrees: Int, + val numFeaturesPerNode: Int) extends Serializable { + + def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) + + def isClassification: Boolean = numClasses >= 2 + + def isMulticlass: Boolean = numClasses > 2 + + def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) + + def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) + + def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + + /** + * Number of splits for the given feature. + * For unordered features, there is 1 bin per split. + * For ordered features, there is 1 more bin than split. + */ + def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { + numBins(featureIndex) + } else { + numBins(featureIndex) - 1 + } + + + /** + * Set number of splits for a continuous feature. + * For a continuous feature, number of bins is number of splits plus 1. + */ + def setNumSplits(featureIndex: Int, numSplits: Int) { + require(isContinuous(featureIndex), + s"Only the number of bins for a continuous feature can be set.") + numBins(featureIndex) = numSplits + 1 + } + + /** + * Indicates if feature subsampling is being used. + */ + def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode + +} + +private[spark] object DecisionTreeMetadata extends Logging { + + /** + * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. + * This computes which categorical features will be ordered vs. unordered, + * as well as the number of splits and bins for each feature. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String): DecisionTreeMetadata = { + + val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse { + throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " + + s"but was given an empty one.") + } + val numExamples = input.count() + val numClasses = strategy.algo match { + case Classification => strategy.numClasses + case Regression => 0 + } + + val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + if (maxPossibleBins < strategy.maxBins) { + logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + + s" (= number of training instances)") + } + + // We check the number of bins here against maxPossibleBins. + // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified + // based on the number of training examples. + if (strategy.categoricalFeaturesInfo.nonEmpty) { + val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + val maxCategory = + strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 + require(maxCategoriesPerFeature <= maxPossibleBins, + s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + + s"number of values in each categorical feature, but categorical feature $maxCategory " + + s"has $maxCategoriesPerFeature values. Consider removing this and other categorical " + "features with a large number of values, or add more training examples.") + } + + val unorderedFeatures = new mutable.HashSet[Int]() + val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) + if (numClasses > 2) { + // Multiclass classification + val maxCategoriesForUnorderedFeature = + ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } + } + } + } else { + // Binary classification or regression + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } + } + } + + // Set number of features to use per node (for random forests). + val _featureSubsetStrategy = featureSubsetStrategy match { + case "auto" => + if (numTrees == 1) { + "all" + } else { + if (strategy.algo == Classification) { + "sqrt" + } else { + "onethird" + } + } + case _ => featureSubsetStrategy + } + val numFeaturesPerNode: Int = _featureSubsetStrategy match { + case "all" => numFeatures + case "sqrt" => math.sqrt(numFeatures).ceil.toInt + case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) + case "onethird" => (numFeatures / 3.0).ceil.toInt + } + + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, + strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, + strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) + } + + /** + * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy): DecisionTreeMetadata = { + buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") + } + + /** + * Given the arity of a categorical feature (arity = number of categories), + * return the number of bins for the feature if it is to be treated as an unordered feature. + * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; + * there are math.pow(2, arity - 1) - 1 such splits. + * Each split has 2 corresponding bins.
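To make the bin bookkeeping above concrete, a short worked sketch (the numbers are assumed for illustration):

    // 4 categories can be partitioned into 2 disjoint, non-empty sets in 2^(4-1) - 1 = 7 ways,
    // which is the value numUnorderedBins(4) returns below.
    val binsForArity4 = (1 << (4 - 1)) - 1                      // 7
    // With maxPossibleBins = 32, arity 5 is the largest still treated as unordered,
    // since 2 * ((1 << (5 - 1)) - 1) = 30 <= 32, while arity 6 would need 62 bins.
    val maxCategoriesForUnordered =
      ((math.log(32 / 2 + 1) / math.log(2.0)) + 1).floor.toInt  // 5
    // Per-node feature subsets under "auto" with numTrees > 1, for, say, 100 features:
    val sqrtFeatures = math.sqrt(100).ceil.toInt                // 10 (classification)
    val oneThirdFeatures = (100 / 3.0).ceil.toInt               // 34 (regression)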
+ */ + def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index b37f4e891e..0749d93b7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} -import org.apache.spark.mllib.tree.impl.TimeTracker import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 2c8286766f..9d697a36b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.{LearningNode, Split} -import org.apache.spark.mllib.tree.impl.BaggedPoint import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index cccf052b3e..7b1fd089f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -28,8 +28,6 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, - TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD @@ -330,7 +328,7 @@ private[spark] object RandomForest extends Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] + * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata * @param topNodes Root node for each tree. Used for matching instances with nodes. * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala new file mode 100644 index 0000000000..4cc250aa46 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.mutable.{HashMap => MutableHashMap} + +/** + * Time tracker implementation which holds labeled timers. + */ +private[spark] class TimeTracker extends Serializable { + + private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + /** + * Starts a new timer, or re-starts a stopped timer. + */ + def start(timerLabel: String): Unit = { + val currentTime = System.nanoTime() + if (starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + + s" timerLabel = $timerLabel before that timer was stopped.") + } + starts(timerLabel) = currentTime + } + + /** + * Stops a timer and returns the elapsed time in seconds. + */ + def stop(timerLabel: String): Double = { + val currentTime = System.nanoTime() + if (!starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + + s" timerLabel = $timerLabel, but that timer was not started.") + } + val elapsed = currentTime - starts(timerLabel) + starts.remove(timerLabel) + if (totals.contains(timerLabel)) { + totals(timerLabel) += elapsed + } else { + totals(timerLabel) = elapsed + } + elapsed / 1e9 + } + + /** + * Print all timing results in seconds. 
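A hypothetical usage sketch for the timer above (not part of the patch; the label names and the sleep are placeholders, and the class is private[spark], so this assumes Spark-internal code):

    val timer = new TimeTracker()
    timer.start("total")
    timer.start("findSplits")
    Thread.sleep(5)                                // stand-in for real work
    val findSplitsSec = timer.stop("findSplits")   // elapsed seconds for this label
    timer.stop("total")
    println(s"Internal timing:\n$timer")           // toString prints per-label totals in seconds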
+ */ + override def toString: String = { + totals.map { case (label, elapsed) => + s" $label: ${elapsed / 1e9}" + }.mkString("\n") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala index 9fa27e5e1f..3a2bf3c725 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.ml.tree.{ContinuousSplit, Split} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index d166dc7905..0f0c6b466d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,11 +20,11 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.impl.TimeTracker import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.impl.TimeTracker import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} import org.apache.spark.rdd.RDD @@ -165,6 +165,7 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. + * * @param input Training dataset. * @param validationInput Validation dataset, ignored if validate is set to false. * @param boostingStrategy Boosting parameters. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala deleted file mode 100644 index 572815df0b..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.commons.math3.distribution.PoissonDistribution - -import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils -import org.apache.spark.util.random.XORShiftRandom - -/** - * Internal representation of a datapoint which belongs to several subsamples of the same dataset, - * particularly for bagging (e.g., for random forests). - * - * This holds one instance, as well as an array of weights which represent the (weighted) - * number of times which this instance appears in each subsamplingRate. - * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that - * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. - * - * @param datum Data instance - * @param subsampleWeights Weight of this instance in each subsampled dataset. - * - * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted - * dataset support, update. (We store subsampleWeights as Double for this future extension.) - */ -private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) - extends Serializable - -private[spark] object BaggedPoint { - - /** - * Convert an input dataset into its BaggedPoint representation, - * choosing subsamplingRate counts for each instance. - * Each subsamplingRate has the same number of instances as the original dataset, - * and is created by subsampling without replacement. - * @param input Input dataset. - * @param subsamplingRate Fraction of the training data used for learning decision tree. - * @param numSubsamples Number of subsamples of this RDD to take. - * @param withReplacement Sampling with/without replacement. - * @param seed Random seed. - * @return BaggedPoint dataset representation. - */ - def convertToBaggedRDD[Datum] ( - input: RDD[Datum], - subsamplingRate: Double, - numSubsamples: Int, - withReplacement: Boolean, - seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { - if (withReplacement) { - convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) - } else { - if (numSubsamples == 1 && subsamplingRate == 1.0) { - convertToBaggedRDDWithoutSampling(input) - } else { - convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) - } - } - } - - private def convertToBaggedRDDSamplingWithoutReplacement[Datum] ( - input: RDD[Datum], - subsamplingRate: Double, - numSubsamples: Int, - seed: Long): RDD[BaggedPoint[Datum]] = { - input.mapPartitionsWithIndex { (partitionIndex, instances) => - // Use random seed = seed + partitionIndex + 1 to make generation reproducible. - val rng = new XORShiftRandom - rng.setSeed(seed + partitionIndex + 1) - instances.map { instance => - val subsampleWeights = new Array[Double](numSubsamples) - var subsampleIndex = 0 - while (subsampleIndex < numSubsamples) { - val x = rng.nextDouble() - subsampleWeights(subsampleIndex) = { - if (x < subsamplingRate) 1.0 else 0.0 - } - subsampleIndex += 1 - } - new BaggedPoint(instance, subsampleWeights) - } - } - } - - private def convertToBaggedRDDSamplingWithReplacement[Datum] ( - input: RDD[Datum], - subsample: Double, - numSubsamples: Int, - seed: Long): RDD[BaggedPoint[Datum]] = { - input.mapPartitionsWithIndex { (partitionIndex, instances) => - // Use random seed = seed + partitionIndex + 1 to make generation reproducible. 
- val poisson = new PoissonDistribution(subsample) - poisson.reseedRandomGenerator(seed + partitionIndex + 1) - instances.map { instance => - val subsampleWeights = new Array[Double](numSubsamples) - var subsampleIndex = 0 - while (subsampleIndex < numSubsamples) { - subsampleWeights(subsampleIndex) = poisson.sample() - subsampleIndex += 1 - } - new BaggedPoint(instance, subsampleWeights) - } - } - } - - private def convertToBaggedRDDWithoutSampling[Datum] ( - input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { - input.map(datum => new BaggedPoint(datum, Array(1.0))) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala deleted file mode 100644 index c745e9f8db..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.spark.mllib.tree.impurity._ - - - -/** - * DecisionTree statistics aggregator for a node. - * This holds a flat array of statistics for a set of (features, bins) - * and helps with indexing. - * This class is abstract to support learning with and without feature subsampling. - */ -private[spark] class DTStatsAggregator( - val metadata: DecisionTreeMetadata, - featureSubset: Option[Array[Int]]) extends Serializable { - - /** - * [[ImpurityAggregator]] instance specifying the impurity type. - */ - val impurityAggregator: ImpurityAggregator = metadata.impurity match { - case Gini => new GiniAggregator(metadata.numClasses) - case Entropy => new EntropyAggregator(metadata.numClasses) - case Variance => new VarianceAggregator() - case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") - } - - /** - * Number of elements (Double values) used for the sufficient statistics of each bin. - */ - private val statsSize: Int = impurityAggregator.statsSize - - /** - * Number of bins for each feature. This is indexed by the feature index. - */ - private val numBins: Array[Int] = { - if (featureSubset.isDefined) { - featureSubset.get.map(metadata.numBins(_)) - } else { - metadata.numBins - } - } - - /** - * Offset for each feature for calculating indices into the [[allStats]] array. - */ - private val featureOffsets: Array[Int] = { - numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - - /** - * Total number of elements stored in this aggregator - */ - private val allStatsSize: Int = featureOffsets.last - - /** - * Flat array of elements. 
- * Index for start of stats for a (feature, bin) is: - * index = featureOffsets(featureIndex) + binIndex * statsSize - */ - private val allStats: Array[Double] = new Array[Double](allStatsSize) - - /** - * Array of parent node sufficient stats. - * - * Note: this is necessary because stats for the parent node are not available - * on the first iteration of tree learning. - */ - private val parentStats: Array[Double] = new Array[Double](statsSize) - - /** - * Get an [[ImpurityCalculator]] for a given (node, feature, bin). - * @param featureOffset This is a pre-computed (node, feature) offset - * from [[getFeatureOffset]]. - */ - def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { - impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) - } - - /** - * Get an [[ImpurityCalculator]] for the parent node. - */ - def getParentImpurityCalculator(): ImpurityCalculator = { - impurityAggregator.getCalculator(parentStats, 0) - } - - /** - * Update the stats for a given (feature, bin) for ordered features, using the given label. - */ - def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { - val i = featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) - } - - /** - * Update the parent node stats using the given label. - */ - def updateParent(label: Double, instanceWeight: Double): Unit = { - impurityAggregator.update(parentStats, 0, label, instanceWeight) - } - - /** - * Faster version of [[update]]. - * Update the stats for a given (feature, bin), using the given label. - * @param featureOffset This is a pre-computed feature offset - * from [[getFeatureOffset]]. - */ - def featureUpdate( - featureOffset: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, - label, instanceWeight) - } - - /** - * Pre-compute feature offset for use with [[featureUpdate]]. - * For ordered features only. - */ - def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) - - /** - * For a given feature, merge the stats for two bins. - * @param featureOffset This is a pre-computed feature offset - * from [[getFeatureOffset]]. - * @param binIndex The other bin is merged into this bin. - * @param otherBinIndex This bin is not modified. - */ - def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { - impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize, - featureOffset + otherBinIndex * statsSize) - } - - /** - * Merge this aggregator with another, and returns this aggregator. - * This method modifies this aggregator in-place. - */ - def merge(other: DTStatsAggregator): DTStatsAggregator = { - require(allStatsSize == other.allStatsSize, - s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." - + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") - var i = 0 - // TODO: Test BLAS.axpy - while (i < allStatsSize) { - allStats(i) += other.allStats(i) - i += 1 - } - - require(statsSize == other.statsSize, - s"DTStatsAggregator.merge requires that both aggregators have the same length parent " + - s"stats vectors. 
This aggregator's parent stats are length $statsSize, " + - s"but the other is ${other.statsSize}.") - var j = 0 - while (j < statsSize) { - parentStats(j) += other.parentStats(j) - j += 1 - } - - this - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala deleted file mode 100644 index 4f27dc44ef..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import scala.collection.mutable - -import org.apache.spark.internal.Logging -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impurity.Impurity -import org.apache.spark.rdd.RDD - -/** - * Learning and dataset metadata for DecisionTree. - * - * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. - * For regression: fixed at 0 (no meaning). - * @param maxBins Maximum number of bins, for all features. - * @param featureArity Map: categorical feature index --> arity. - * I.e., the feature takes values in {0, ..., arity - 1}. - * @param numBins Number of bins for each feature. - */ -private[spark] class DecisionTreeMetadata( - val numFeatures: Int, - val numExamples: Long, - val numClasses: Int, - val maxBins: Int, - val featureArity: Map[Int, Int], - val unorderedFeatures: Set[Int], - val numBins: Array[Int], - val impurity: Impurity, - val quantileStrategy: QuantileStrategy, - val maxDepth: Int, - val minInstancesPerNode: Int, - val minInfoGain: Double, - val numTrees: Int, - val numFeaturesPerNode: Int) extends Serializable { - - def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) - - def isClassification: Boolean = numClasses >= 2 - - def isMulticlass: Boolean = numClasses > 2 - - def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) - - def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) - - def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) - - /** - * Number of splits for the given feature. - * For unordered features, there is 1 bin per split. - * For ordered features, there is 1 more bin than split. - */ - def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { - numBins(featureIndex) - } else { - numBins(featureIndex) - 1 - } - - - /** - * Set number of splits for a continuous feature. 
- * For a continuous feature, number of bins is number of splits plus 1. - */ - def setNumSplits(featureIndex: Int, numSplits: Int) { - require(isContinuous(featureIndex), - s"Only number of bin for a continuous feature can be set.") - numBins(featureIndex) = numSplits + 1 - } - - /** - * Indicates if feature subsampling is being used. - */ - def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode - -} - -private[spark] object DecisionTreeMetadata extends Logging { - - /** - * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. - * This computes which categorical features will be ordered vs. unordered, - * as well as the number of splits and bins for each feature. - */ - def buildMetadata( - input: RDD[LabeledPoint], - strategy: Strategy, - numTrees: Int, - featureSubsetStrategy: String): DecisionTreeMetadata = { - - val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse { - throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " + - s"but was given by empty one.") - } - val numExamples = input.count() - val numClasses = strategy.algo match { - case Classification => strategy.numClasses - case Regression => 0 - } - - val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt - if (maxPossibleBins < strategy.maxBins) { - logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + - s" (= number of training instances)") - } - - // We check the number of bins here against maxPossibleBins. - // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified - // based on the number of training examples. - if (strategy.categoricalFeaturesInfo.nonEmpty) { - val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max - val maxCategory = - strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 - require(maxCategoriesPerFeature <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + - s"number of values in each categorical feature, but categorical feature $maxCategory " + - s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " + - "features with a large number of values, or add more training examples.") - } - - val unorderedFeatures = new mutable.HashSet[Int]() - val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) - if (numClasses > 2) { - // Multiclass classification - val maxCategoriesForUnorderedFeature = - ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt - strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Hack: If a categorical feature has only 1 category, we treat it as continuous. - // TODO(SPARK-9957): Handle this properly by filtering out those features. - if (numCategories > 1) { - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. 
- // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories - } - } - } - } else { - // Binary classification or regression - strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 - if (numCategories > 1) { - numBins(featureIndex) = numCategories - } - } - } - - // Set number of features to use per node (for random forests). - val _featureSubsetStrategy = featureSubsetStrategy match { - case "auto" => - if (numTrees == 1) { - "all" - } else { - if (strategy.algo == Classification) { - "sqrt" - } else { - "onethird" - } - } - case _ => featureSubsetStrategy - } - val numFeaturesPerNode: Int = _featureSubsetStrategy match { - case "all" => numFeatures - case "sqrt" => math.sqrt(numFeatures).ceil.toInt - case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) - case "onethird" => (numFeatures / 3.0).ceil.toInt - } - - new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, - strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, - strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) - } - - /** - * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. - */ - def buildMetadata( - input: RDD[LabeledPoint], - strategy: Strategy): DecisionTreeMetadata = { - buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") - } - - /** - * Given the arity of a categorical feature (arity = number of categories), - * return the number of bins for the feature if it is to be treated as an unordered feature. - * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; - * there are math.pow(2, arity - 1) - 1 such splits. - * Each split has 2 corresponding bins. - */ - def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala deleted file mode 100644 index dc7e969f7b..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.tree.impl - -import scala.collection.mutable - -import org.apache.hadoop.fs.{FileSystem, Path} - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.model.{Bin, Node, Split} -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - -/** - * :: DeveloperApi :: - * This is used by the node id cache to find the child id that a data point would belong to. - * @param split Split information. - * @param nodeIndex The current node index of a data point that this will update. - */ -@DeveloperApi -private[tree] case class NodeIndexUpdater( - split: Split, - nodeIndex: Int) { - /** - * Determine a child node index based on the feature value and the split. - * @param binnedFeatures Binned feature values. - * @param bins Bin information to convert the bin indices to approximate feature values. - * @return Child node index to update to. - */ - def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = { - if (split.featureType == Continuous) { - val featureIndex = split.feature - val binIndex = binnedFeatures(featureIndex) - val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold - if (featureValueUpperBound <= split.threshold) { - Node.leftChildIndex(nodeIndex) - } else { - Node.rightChildIndex(nodeIndex) - } - } else { - if (split.categories.contains(binnedFeatures(split.feature).toDouble)) { - Node.leftChildIndex(nodeIndex) - } else { - Node.rightChildIndex(nodeIndex) - } - } - } -} - -/** - * :: DeveloperApi :: - * A given TreePoint would belong to a particular node per tree. - * Each row in the nodeIdsForInstances RDD is an array over trees of the node index - * in each tree. Initially, values should all be 1 for root node. - * The nodeIdsForInstances RDD needs to be updated at each iteration. - * @param nodeIdsForInstances The initial values in the cache - * (should be an Array of all 1's (meaning the root nodes)). - * @param checkpointInterval The checkpointing interval - * (how often should the cache be checkpointed.). - */ -@DeveloperApi -private[spark] class NodeIdCache( - var nodeIdsForInstances: RDD[Array[Int]], - val checkpointInterval: Int) { - - // Keep a reference to a previous node Ids for instances. - // Because we will keep on re-persisting updated node Ids, - // we want to unpersist the previous RDD. - private var prevNodeIdsForInstances: RDD[Array[Int]] = null - - // To keep track of the past checkpointed RDDs. - private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() - private var rddUpdateCount = 0 - - /** - * Update the node index values in the cache. - * This updates the RDD and its lineage. - * TODO: Passing bin information to executors seems unnecessary and costly. - * @param data The RDD of training rows. - * @param nodeIdUpdaters A map of node index updaters. - * The key is the indices of nodes that we want to update. - * @param bins Bin information needed to find child node indices. - */ - def updateNodeIndices( - data: RDD[BaggedPoint[TreePoint]], - nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], - bins: Array[Array[Bin]]): Unit = { - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. 
- prevNodeIdsForInstances.unpersist() - } - - prevNodeIdsForInstances = nodeIdsForInstances - nodeIdsForInstances = data.zip(nodeIdsForInstances).map { - case (point, node) => { - var treeId = 0 - while (treeId < nodeIdUpdaters.length) { - val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null) - if (nodeIdUpdater != null) { - val newNodeIndex = nodeIdUpdater.updateNodeIndex( - binnedFeatures = point.datum.binnedFeatures, - bins = bins) - node(treeId) = newNodeIndex - } - - treeId += 1 - } - - node - } - } - - // Keep on persisting new ones. - nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) - rddUpdateCount += 1 - - // Handle checkpointing if the directory is not None. - if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty && - (rddUpdateCount % checkpointInterval) == 0) { - // Let's see if we can delete previous checkpoints. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // We can delete the oldest checkpoint iff - // the next checkpoint actually exists in the file system. - if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) { - val old = checkpointQueue.dequeue() - - // Since the old checkpoint is not deleted by Spark, - // we'll manually delete it here. - val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(old.getCheckpointFile.get), true) - } else { - canDelete = false - } - } - - nodeIdsForInstances.checkpoint() - checkpointQueue.enqueue(nodeIdsForInstances) - } - } - - /** - * Call this after training is finished to delete any remaining checkpoints. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.nonEmpty) { - val old = checkpointQueue.dequeue() - for (checkpointFile <- old.getCheckpointFile) { - val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(checkpointFile), true) - } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() - } - } -} - -private[spark] object NodeIdCache { - /** - * Initialize the node Id cache with initial node Id values. - * @param data The RDD of training rows. - * @param numTrees The number of trees that we want to create cache for. - * @param checkpointInterval The checkpointing interval - * (how often should the cache be checkpointed.). - * @param initVal The initial values in the cache. - * @return A node Id cache containing an RDD of initial root node Indices. - */ - def init( - data: RDD[BaggedPoint[TreePoint]], - numTrees: Int, - checkpointInterval: Int, - initVal: Int = 1): NodeIdCache = { - new NodeIdCache( - data.map(_ => Array.fill[Int](numTrees)(initVal)), - checkpointInterval) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala deleted file mode 100644 index 70afaa162b..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import scala.collection.mutable.{HashMap => MutableHashMap} - -/** - * Time tracker implementation which holds labeled timers. - */ -private[spark] class TimeTracker extends Serializable { - - private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() - - private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() - - /** - * Starts a new timer, or re-starts a stopped timer. - */ - def start(timerLabel: String): Unit = { - val currentTime = System.nanoTime() - if (starts.contains(timerLabel)) { - throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + - s" timerLabel = $timerLabel before that timer was stopped.") - } - starts(timerLabel) = currentTime - } - - /** - * Stops a timer and returns the elapsed time in seconds. - */ - def stop(timerLabel: String): Double = { - val currentTime = System.nanoTime() - if (!starts.contains(timerLabel)) { - throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + - s" timerLabel = $timerLabel, but that timer was not started.") - } - val elapsed = currentTime - starts(timerLabel) - starts.remove(timerLabel) - if (totals.contains(timerLabel)) { - totals(timerLabel) += elapsed - } else { - totals(timerLabel) = elapsed - } - elapsed / 1e9 - } - - /** - * Print all timing results in seconds. - */ - override def toString: String = { - totals.map { case (label, elapsed) => - s" $label: ${elapsed / 1e9}" - }.mkString("\n") - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala deleted file mode 100644 index 21919d69a3..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.Bin -import org.apache.spark.rdd.RDD - - -/** - * Internal representation of LabeledPoint for DecisionTree. - * This bins feature values based on a subsampled of data as follows: - * (a) Continuous features are binned into ranges. - * (b) Unordered categorical features are binned based on subsets of feature values. 
- * "Unordered categorical features" are categorical features with low arity used in - * multiclass classification. - * (c) Ordered categorical features are binned based on feature values. - * "Ordered categorical features" are categorical features with high arity, - * or any categorical feature used in regression or binary classification. - * - * @param label Label from LabeledPoint - * @param binnedFeatures Binned feature values. - * Same length as LabeledPoint.features, but values are bin indices. - */ -private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) - extends Serializable { -} - -private[spark] object TreePoint { - - /** - * Convert an input dataset into its TreePoint representation, - * binning feature values in preparation for DecisionTree training. - * @param input Input dataset. - * @param bins Bins for features, of size (numFeatures, numBins). - * @param metadata Learning and dataset metadata - * @return TreePoint dataset representation - */ - def convertToTreeRDD( - input: RDD[LabeledPoint], - bins: Array[Array[Bin]], - metadata: DecisionTreeMetadata): RDD[TreePoint] = { - // Construct arrays for featureArity for efficiency in the inner loop. - val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) - var featureIndex = 0 - while (featureIndex < metadata.numFeatures) { - featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) - featureIndex += 1 - } - input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, featureArity) - } - } - - /** - * Convert one LabeledPoint into its TreePoint representation. - * @param bins Bins for features, of size (numFeatures, numBins). - * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories - * for categorical features. - */ - private def labeledPointToTreePoint( - labeledPoint: LabeledPoint, - bins: Array[Array[Bin]], - featureArity: Array[Int]): TreePoint = { - val numFeatures = labeledPoint.features.size - val arr = new Array[Int](numFeatures) - var featureIndex = 0 - while (featureIndex < numFeatures) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), - bins) - featureIndex += 1 - } - new TreePoint(labeledPoint.label, arr) - } - - /** - * Find bin for one (labeledPoint, feature). - * - * @param featureArity 0 for continuous features; number of categories for categorical features. - * @param bins Bins for features, of size (numFeatures, numBins). - */ - private def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - featureArity: Int, - bins: Array[Array[Bin]]): Int = { - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } else if (lowThreshold >= feature) { - right = mid - 1 - } else { - left = mid + 1 - } - } - -1 - } - - if (featureArity == 0) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new RuntimeException("No bin was found for continuous feature." + - " This error can occur when given invalid data values (such as NaN)." 
+ - s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") - } - binIndex - } else { - // Categorical feature bins are indexed by feature values. - val featureValue = labeledPoint.features(featureIndex) - if (featureValue < 0 || featureValue >= featureArity) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureArity - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - featureValue.toInt - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 13aff11007..ff7700d2d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -85,7 +85,7 @@ object Entropy extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[tree] class EntropyAggregator(numClasses: Int) +private[spark] class EntropyAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 39c7f9c3be..58dc79b739 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -81,7 +81,7 @@ object Gini extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[tree] class GiniAggregator(numClasses: Int) +private[spark] class GiniAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 92d74a1b83..2423516123 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -71,7 +71,7 @@ object Variance extends Impurity { * in order to compute impurity from a sample. * Note: Instances of this class do not hold the data; they operate on views of the data. */ -private[tree] class VarianceAggregator() +private[spark] class VarianceAggregator() extends ImpurityAggregator(statsSize = 3) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala deleted file mode 100644 index 0cad473782..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import org.apache.spark.mllib.tree.configuration.FeatureType._ - -/** - * Used for "binning" the feature values for faster best split calculation. - * - * For a continuous feature, the bin is determined by a low and a high split, - * where an example with featureValue falls into the bin s.t. - * lowSplit.threshold < featureValue <= highSplit.threshold. - * - * For ordered categorical features, there is a 1-1-1 correspondence between - * bins, splits, and feature values. The bin is determined by category/feature value. - * However, the bins are not necessarily ordered by feature value; - * they are ordered using impurity. - * - * For unordered categorical features, there is a 1-1 correspondence between bins, splits, - * where bins and splits correspond to subsets of feature values (in highSplit.categories). - * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all - * partitionings of categories into 2 disjoint, non-empty sets. - * - * @param lowSplit signifying the lower threshold for the continuous feature to be - * accepted in the bin - * @param highSplit signifying the upper threshold for the continuous feature to be - * accepted in the bin - * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for ordered features - */ -private[tree] -case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala new file mode 100644 index 0000000000..77ab3d8bb7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.tree.EnsembleTestHelper +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[BaggedPoint]]. 
+ */ +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("BaggedPoint RDD: without subsampling") { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) + baggedRDD.collect().foreach { baggedPoint => + assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 1.0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample)) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample))) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 441338e74e..e64551f03c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -26,7 +26,6 @@ import 
org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bb1041b109..49cb7e1f24 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala deleted file mode 100644 index 9d756da410..0000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.tree.EnsembleTestHelper -import org.apache.spark.mllib.util.MLlibTestSparkContext - -/** - * Test suite for [[BaggedPoint]]. 
- */ -class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { - - test("BaggedPoint RDD: without subsampling") { - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) - baggedRDD.collect().foreach { baggedPoint => - assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) - } - } - - test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") { - val numSubsamples = 100 - val (expectedMean, expectedStddev) = (1.0, 1.0) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") { - val numSubsamples = 100 - val subsample = 0.5 - val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample)) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") { - val numSubsamples = 100 - val (expectedMean, expectedStddev) = (1.0, 0) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") { - val numSubsamples = 100 - val subsample = 0.5 - val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample))) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } -} -- cgit v1.2.3 From 27e71a2cd930ae28c82c9c3ee6476a12ea165fdf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 1 Apr 2016 22:00:24 -0700 Subject: [SPARK-14244][SQL] Don't use SizeBasedWindowFunction.n created on executor side when evaluating window functions ## What changes were proposed in this pull request? 
`SizeBasedWindowFunction.n` is a global singleton attribute created for evaluating size based aggregate window functions like `CUME_DIST`. However, this attribute gets different expression IDs when created on both driver side and executor side. This PR adds `withPartitionSize` method to `SizeBasedWindowFunction` so that we can easily rewrite `SizeBasedWindowFunction.n` on executor side. ## How was this patch tested? A test case is added in `HiveSparkSubmitSuite`, which supports launching multi-process clusters. Author: Cheng Lian Closes #12040 from liancheng/spark-14244-fix-sized-window-function. --- .../catalyst/expressions/windowExpressions.scala | 6 ++- .../org/apache/spark/sql/execution/Window.scala | 22 ++++++++--- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 4 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 44 +++++++++++++++++++++- 4 files changed, 67 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index b8679474cf..c0b453dccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -451,7 +451,11 @@ abstract class RowNumberLike extends AggregateWindowFunction { * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation. */ trait SizeBasedWindowFunction extends AggregateWindowFunction { - protected def n: AttributeReference = SizeBasedWindowFunction.n + // It's made a val so that the attribute created on driver side is serialized to executor side. + // Otherwise, if it's defined as a function, when it's called on executor side, it actually + // returns the singleton value instantiated on executor side, which has different expression ID + // from the one created on driver side. + val n: AttributeReference = SizeBasedWindowFunction.n } object SizeBasedWindowFunction { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 7acf020b28..7d0567842c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -874,7 +874,8 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( * processor class. */ private[execution] object AggregateProcessor { - def apply(functions: Array[Expression], + def apply( + functions: Array[Expression], ordinal: Int, inputAttributes: Seq[Attribute], newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): @@ -885,11 +886,20 @@ private[execution] object AggregateProcessor { val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) val imperatives = mutable.Buffer.empty[ImperativeAggregate] + // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then + // serialized to executor side. These functions all reference a global singleton window + // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect + // the singleton instance created on driver side instead of using executor side + // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID. 
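      // A minimal, self-contained sketch of the failure mode described in the comment above
      // (hypothetical names, not Spark's real classes): an id allocated when a per-JVM singleton
      // initializes makes "the same" attribute compare unequal across driver and executor,
      // while a val captured on the driver travels with the serialized instance and keeps its id.
      //
      //   import java.util.concurrent.atomic.AtomicLong
      //   object IdGen { private val c = new AtomicLong(); def next(): Long = c.getAndIncrement() }
      //   case class Attr(name: String, id: Long = IdGen.next())   // stand-in for AttributeReference/exprId
      //   object PartitionSize { val n: Attr = Attr("window__partition__size") }
      //   trait SizeByDef extends Serializable { def n: Attr = PartitionSize.n } // re-resolved per JVM -> new id
      //   trait SizeByVal extends Serializable { val n: Attr = PartitionSize.n } // fixed on the driver, then serialized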
+ val partitionSize: Option[AttributeReference] = { + val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f }) + aggs.headOption.map(_.n) + } + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. - val trackPartitionSize = functions.exists(_.isInstanceOf[SizeBasedWindowFunction]) - if (trackPartitionSize) { - aggBufferAttributes += SizeBasedWindowFunction.n + partitionSize.foreach { n => + aggBufferAttributes += n initialValues += NoOp updateExpressions += NoOp } @@ -920,7 +930,7 @@ private[execution] object AggregateProcessor { // Create the projections. val initialProjection = newMutableProjection( initialValues, - Seq(SizeBasedWindowFunction.n))() + partitionSize.toSeq)() val updateProjection = newMutableProjection( updateExpressions, aggBufferAttributes ++ inputAttributes)() @@ -935,7 +945,7 @@ private[execution] object AggregateProcessor { updateProjection, evaluateProjection, imperatives.toArray, - trackPartitionSize) + partitionSize.isDefined) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index c07c428895..5ada3d5598 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -107,7 +107,9 @@ private[hive] class HiveFunctionRegistry( // If there is any other error, we throw an AnalysisException. val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " + s"because: ${throwable.getMessage}." - throw new AnalysisException(errorMessage) + val analysisException = new AnalysisException(errorMessage) + analysisException.setStackTrace(throwable.getStackTrace) + throw analysisException } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 16747cab37..53dec6348f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.{QueryTest, Row, SQLContext} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -135,6 +135,19 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-14244 fix window partition size attribute binding failure") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_14244.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. 
// This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -378,3 +391,32 @@ object SPARK_11009 extends QueryTest { } } } + +object SPARK_14244 extends QueryTest { + import org.apache.spark.sql.expressions.Window + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.ui.enabled", "false") + .set("spark.sql.shuffle.partitions", "100")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + + import hiveContext.implicits._ + + try { + val window = Window.orderBy('id) + val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) + checkAnswer(df, Seq(Row(0.5D), Row(1.0D))) + } finally { + sparkContext.stop() + } + } +} -- cgit v1.2.3 From 877dc712e66db69cb320e10ba5edebca401591e3 Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Fri, 1 Apr 2016 22:38:07 -0700 Subject: [SPARK-14138] [SQL] [MASTER] Fix generated SpecificColumnarIterator code can exceed JVM size limit for cached DataFrames ## What changes were proposed in this pull request? This PR reduces Java byte code size of method in ```SpecificColumnarIterator``` by using a approach to make a group for lot of ```ColumnAccessor``` instantiations or method calls (more than 200) into a method ## How was this patch tested? Added a new unit test, which includes large instantiations and method calls, to ```InMemoryColumnarQuerySuite``` Author: Kazuaki Ishizaki Closes #12108 from kiszk/SPARK-14138-master. --- .../columnar/GenerateColumnAccessor.scala | 46 +++++++++++++++++++--- .../columnar/InMemoryColumnarQuerySuite.scala | 10 +++++ 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index d4e5db459f..e2e33e3246 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -88,7 +88,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera case array: ArrayType => classOf[ArrayColumnAccessor].getName case t: MapType => classOf[MapColumnAccessor].getName } - ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;") + ctx.addMutableState(accessorCls, accessorName, "") val createCode = dt match { case t if ctx.isPrimitiveType(dt) => @@ -114,6 +114,42 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera (createCode, extract + patch) }.unzip + /* + * 200 = 6000 bytes / 30 (up to 30 bytes per one call)) + * the maximum byte code size to be compiled for HotSpot is 8000. 
+ * We should keep less than 8000 + */ + val numberOfStatementsThreshold = 200 + val (initializerAccessorCalls, extractorCalls) = + if (initializeAccessors.length <= numberOfStatementsThreshold) { + (initializeAccessors.mkString("\n"), extractors.mkString("\n")) + } else { + val groupedAccessorsItr = initializeAccessors.grouped(numberOfStatementsThreshold) + val groupedExtractorsItr = extractors.grouped(numberOfStatementsThreshold) + var groupedAccessorsLength = 0 + groupedAccessorsItr.zipWithIndex.map { case (body, i) => + groupedAccessorsLength += 1 + val funcName = s"accessors$i" + val funcCode = s""" + |private void $funcName() { + | ${body.mkString("\n")} + |} + """.stripMargin + ctx.addNewFunction(funcName, funcCode) + } + groupedExtractorsItr.zipWithIndex.map { case (body, i) => + val funcName = s"extractors$i" + val funcCode = s""" + |private void $funcName() { + | ${body.mkString("\n")} + |} + """.stripMargin + ctx.addNewFunction(funcName, funcCode) + } + ((0 to groupedAccessorsLength - 1).map { i => s"accessors$i();" }.mkString("\n"), + (0 to groupedAccessorsLength - 1).map { i => s"extractors$i();" }.mkString("\n")) + } + val code = s""" import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -149,8 +185,6 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera this.nativeOrder = ByteOrder.nativeOrder(); this.buffers = new byte[${columnTypes.length}][]; this.mutableRow = new MutableUnsafeRow(rowWriter); - - ${ctx.initMutableStates()} } public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { @@ -159,6 +193,8 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera this.columnIndexes = columnIndexes; } + ${ctx.declareAddedFunctions()} + public boolean hasNext() { if (currentRow < numRowsInBatch) { return true; @@ -173,7 +209,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera for (int i = 0; i < columnIndexes.length; i ++) { buffers[i] = batch.buffers()[columnIndexes[i]]; } - ${initializeAccessors.mkString("\n")} + ${initializerAccessorCalls} return hasNext(); } @@ -182,7 +218,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera currentRow += 1; bufferHolder.reset(); rowWriter.zeroOutNullBytes(); - ${extractors.mkString("\n")} + ${extractorCalls} unsafeRow.setTotalSize(bufferHolder.totalSize()); return unsafeRow; } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 9e04caf8ba..50c8745a28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -220,4 +220,14 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(data.count() === 10) assert(data.filter($"s" === "3").count() === 1) } + + test("SPARK-14138: Generated SpecificColumnarIterator can exceed JVM size limit for cached DF") { + val length1 = 3999 + val columnTypes1 = List.fill(length1)(IntegerType) + val columnarIterator1 = GenerateColumnAccessor.generate(columnTypes1) + + val length2 = 10000 + val columnTypes2 = List.fill(length2)(IntegerType) + val columnarIterator2 = GenerateColumnAccessor.generate(columnTypes2) + } } -- cgit v1.2.3 From fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun 
Date: Fri, 1 Apr 2016 22:45:52 -0700 Subject: [SPARK-14251][SQL] Add SQL command for printing out generated code for debugging ## What changes were proposed in this pull request? This PR implements `EXPLAIN CODEGEN` SQL command which returns generated codes like `debugCodegen`. In `spark-shell`, we don't need to `import debug` module. In `spark-sql`, we can use this SQL command now. **Before** ``` scala> import org.apache.spark.sql.execution.debug._ scala> sql("select 'a' as a group by 1").debugCodegen() Found 2 WholeStageCodegen subtrees. == Subtree 1 / 2 == ... Generated code: ... == Subtree 2 / 2 == ... Generated code: ... ``` **After** ``` scala> sql("explain extended codegen select 'a' as a group by 1").collect().foreach(println) [Found 2 WholeStageCodegen subtrees.] [== Subtree 1 / 2 ==] ... [] [Generated code:] ... [] [== Subtree 2 / 2 ==] ... [] [Generated code:] ... ``` ## How was this patch tested? Pass the Jenkins tests (including new testcases) Author: Dongjoon Hyun Closes #12099 from dongjoon-hyun/SPARK-14251. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 5 ++- .../spark/sql/execution/SparkSqlParser.scala | 3 +- .../spark/sql/execution/command/commands.scala | 15 ++++++-- .../apache/spark/sql/execution/debug/package.scala | 43 +++++++++++----------- .../spark/sql/execution/debug/DebuggingSuite.scala | 2 +- .../apache/spark/sql/hive/execution/commands.scala | 1 - .../sql/hive/execution/HiveExplainSuite.scala | 29 +++++++++++++++ 7 files changed, 67 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d1747b9915..f34bb061e4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -584,7 +584,7 @@ frameBound explainOption - : LOGICAL | FORMATTED | EXTENDED + : LOGICAL | FORMATTED | EXTENDED | CODEGEN ; transactionMode @@ -633,7 +633,7 @@ nonReserved | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP - | EXPLAIN | FORMAT | LOGICAL | FORMATTED + | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF | SET | VIEW | REPLACE @@ -724,6 +724,7 @@ DESCRIBE: 'DESCRIBE'; EXPLAIN: 'EXPLAIN'; FORMAT: 'FORMAT'; LOGICAL: 'LOGICAL'; +CODEGEN: 'CODEGEN'; CAST: 'CAST'; SHOW: 'SHOW'; TABLES: 'TABLES'; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 7efe98dd18..ff3ab7746c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -136,7 +136,8 @@ class SparkSqlAstBuilder extends AstBuilder { // Create the explain comment. 
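      // Rough mapping of what the explain options resolve to after this change (orientation only;
      // the authoritative behavior is the code below and in ExplainCommand.run):
      //   EXPLAIN <stmt>            -> ExplainCommand(plan)                   // simple physical plan string
      //   EXPLAIN EXTENDED <stmt>   -> ExplainCommand(plan, extended = true)  // full QueryExecution output
      //   EXPLAIN CODEGEN <stmt>    -> ExplainCommand(plan, codegen = true)   // generated code via codegenString
      // If both EXTENDED and CODEGEN are given, the codegen branch is checked first in ExplainCommand.run.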
val statement = plan(ctx.statement) if (isExplainableStatement(statement)) { - ExplainCommand(statement, extended = options.exists(_.EXTENDED != null)) + ExplainCommand(statement, extended = options.exists(_.EXTENDED != null), + codegen = options.exists(_.CODEGEN != null)) } else { ExplainCommand(OneRowRelation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index f90d8717ca..4bc62cdc4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - /** * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. @@ -237,15 +237,22 @@ case class ExplainCommand( logicalPlan: LogicalPlan, override val output: Seq[Attribute] = Seq(AttributeReference("plan", StringType, nullable = true)()), - extended: Boolean = false) + extended: Boolean = false, + codegen: Boolean = false) extends RunnableCommand { // Run through the optimizer to generate the physical plan. override def run(sqlContext: SQLContext): Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. val queryExecution = sqlContext.executePlan(logicalPlan) - val outputString = if (extended) queryExecution.toString else queryExecution.simpleString - + val outputString = + if (codegen) { + codegenString(queryExecution.executedPlan) + } else if (extended) { + queryExecution.toString + } else { + queryExecution.simpleString + } outputString.split("\n").map(Row(_)) } catch { case cause: TreeNodeException[_] => ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 3a174ed94c..7b0c8ebdfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -48,6 +48,25 @@ package object debug { // scalastyle:on println } + def codegenString(plan: SparkPlan): String = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]() + plan transform { + case s: WholeStageCodegen => + codegenSubtrees += s + s + case s => s + } + var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" + for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" + output += s + output += "\nGenerated code:\n" + val (_, source) = s.doCodeGen() + output += s"${CodeFormatter.format(source)}\n" + } + output + } + /** * Augments [[SQLContext]] with debug methods. */ @@ -81,28 +100,7 @@ package object debug { * WholeStageCodegen subtree). */ def debugCodegen(): Unit = { - debugPrint(debugCodegenString()) - } - - /** Visible for testing. 
*/ - def debugCodegenString(): String = { - val plan = query.queryExecution.executedPlan - val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]() - plan transform { - case s: WholeStageCodegen => - codegenSubtrees += s - s - case s => s - } - var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" - for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { - output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" - output += s - output += "\nGenerated code:\n" - val (_, source) = s.doCodeGen() - output += s"${CodeFormatter.format(source)}\n" - } - output + debugPrint(codegenString(query.queryExecution.executedPlan)) } } @@ -123,6 +121,7 @@ package object debug { /** * A collection of metrics for each column of output. + * * @param elementTypes the actual runtime types for the output. Useful when there are bugs * causing the wrong data to be projected. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 979265e274..c0fce4b96a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -27,7 +27,7 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = sqlContext.range(10).groupBy("id").count().debugCodegenString() + val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index cd26a68f35..64d1341a47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSource, LogicalRelation} import org.apache.spark.sql.hive.HiveContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index b7ef5d1db7..c45d49d6c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -101,4 +101,33 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "Physical Plan should not contain Subquery since it's eliminated by optimizer") } } + + test("EXPLAIN CODEGEN command") { + checkExistence(sql("EXPLAIN CODEGEN SELECT 1"), true, + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + "/* 002 */ return new GeneratedIterator(references);", + "/* 003 */ }" + ) + + checkExistence(sql("EXPLAIN CODEGEN SELECT 1"), false, + "== Physical Plan ==" + ) + + checkExistence(sql("EXPLAIN EXTENDED 
CODEGEN SELECT 1"), true, + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + "/* 002 */ return new GeneratedIterator(references);", + "/* 003 */ }" + ) + + checkExistence(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), false, + "== Parsed Logical Plan ==", + "== Analyzed Logical Plan ==", + "== Optimized Logical Plan ==", + "== Physical Plan ==" + ) + } } -- cgit v1.2.3 From f414154418c2291448954b9f0890d592b2d823ae Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 1 Apr 2016 22:46:56 -0700 Subject: [SPARK-14285][SQL] Implement common type-safe aggregate functions ## What changes were proposed in this pull request? In the Dataset API, it is fairly difficult for users to perform simple aggregations in a type-safe way at the moment because there are no aggregators that have been implemented. This pull request adds a few common aggregate functions in expressions.scala.typed package, and also creates the expressions.java.typed package without implementation. The java implementation should probably come as a separate pull request. One challenge there is to resolve the type difference between Scala primitive types and Java boxed types. ## How was this patch tested? Added unit tests for them. Author: Reynold Xin Closes #12077 from rxin/SPARK-14285. --- .../sql/execution/aggregate/typedaggregators.scala | 69 ++++++++++++ .../org/apache/spark/sql/expressions/Window.scala | 8 +- .../apache/spark/sql/expressions/WindowSpec.scala | 8 +- .../apache/spark/sql/expressions/java/typed.java | 34 ++++++ .../apache/spark/sql/expressions/scala/typed.scala | 89 +++++++++++++++ .../org/apache/spark/sql/expressions/udaf.scala | 4 +- .../org/apache/spark/sql/JavaDatasetSuite.java | 54 --------- .../sql/sources/JavaDatasetAggregatorSuite.java | 123 +++++++++++++++++++++ .../apache/spark/sql/DatasetAggregatorSuite.scala | 64 +++-------- 9 files changed, 342 insertions(+), 111 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala new file mode 100644 index 0000000000..9afc29038b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.expressions.Aggregator + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// This file defines internal implementations for aggregators. +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] { + val numeric = implicitly[Numeric[OUT]] + override def zero: OUT = numeric.zero + override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) + override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) + override def finish(reduction: OUT): OUT = reduction +} + + +class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] { + override def zero: Double = 0.0 + override def reduce(b: Double, a: IN): Double = b + f(a) + override def merge(b1: Double, b2: Double): Double = b1 + b2 + override def finish(reduction: Double): Double = reduction +} + + +class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] { + override def zero: Long = 0L + override def reduce(b: Long, a: IN): Long = b + f(a) + override def merge(b1: Long, b2: Long): Long = b1 + b2 + override def finish(reduction: Long): Long = reduction +} + + +class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] { + override def zero: Long = 0 + override def reduce(b: Long, a: IN): Long = { + if (f(a) == null) b else b + 1 + } + override def merge(b1: Long, b2: Long): Long = b1 + b2 + override def finish(reduction: Long): Long = reduction +} + + +class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), Double] { + override def zero: (Double, Long) = (0.0, 0L) + override def reduce(b: (Double, Long), a: IN): (Double, Long) = (f(a) + b._1, 1 + b._2) + override def finish(reduction: (Double, Long)): Double = reduction._1 / reduction._2 + override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = { + (b1._1 + b2._1, b1._2 + b2._2) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala index e9b60841fc..350c283646 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Window.scala @@ -42,7 +42,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { spec.partitionBy(colName, colNames : _*) } @@ -51,7 +51,7 @@ object Window { * Creates a [[WindowSpec]] with the partitioning defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { spec.partitionBy(cols : _*) } @@ -60,7 +60,7 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { spec.orderBy(colName, colNames : _*) } @@ -69,7 +69,7 @@ object Window { * Creates a [[WindowSpec]] with the ordering defined. 
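   // Why the annotation is now written with the `_root_` prefix (an inference from this patch,
   // not spelled out in the description): the new package org.apache.spark.sql.expressions.scala
   // shadows the top-level `scala` package inside org.apache.spark.sql.expressions, so a bare
   // `scala.annotation.varargs` would resolve against the new sibling package and no longer compile.
   // A minimal stand-alone illustration with hypothetical package names:
   //
   //   package a.b.scala { object Marker }
   //   package a.b {
   //     object Demo {
   //       // val bad = scala.Predef        // does not compile here: `scala` now means a.b.scala
   //       val ok = _root_.scala.Predef     // `_root_` forces resolution from the root package
   //     }
   //   }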
* @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { spec.orderBy(cols : _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 9e9c58cb66..d716da2668 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -39,7 +39,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(colName: String, colNames: String*): WindowSpec = { partitionBy((colName +: colNames).map(Column(_)): _*) } @@ -48,7 +48,7 @@ class WindowSpec private[sql]( * Defines the partitioning columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def partitionBy(cols: Column*): WindowSpec = { new WindowSpec(cols.map(_.expr), orderSpec, frame) } @@ -57,7 +57,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(colName: String, colNames: String*): WindowSpec = { orderBy((colName +: colNames).map(Column(_)): _*) } @@ -66,7 +66,7 @@ class WindowSpec private[sql]( * Defines the ordering columns in a [[WindowSpec]]. * @since 1.4.0 */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def orderBy(cols: Column*): WindowSpec = { val sortOrder: Seq[SortOrder] = cols.map { col => col.expr match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java new file mode 100644 index 0000000000..cdba970d8f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions.java; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.Dataset; + +/** + * :: Experimental :: + * Type-safe functions available for {@link Dataset} operations in Java. + * + * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}. 
+ * + * @since 2.0.0 + */ +@Experimental +public class typed { + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala new file mode 100644 index 0000000000..d0eb190afd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scala/typed.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions.scala + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql._ +import org.apache.spark.sql.execution.aggregate._ + +/** + * :: Experimental :: + * Type-safe functions available for [[Dataset]] operations in Scala. + * + * Java users should use [[org.apache.spark.sql.expressions.java.typed]]. + * + * @since 2.0.0 + */ +@Experimental +// scalastyle:off +object typed { + // scalastyle:on + + // Note: whenever we update this file, we should update the corresponding Java version too. + // The reason we have separate files for Java and Scala is because in the Scala version, we can + // use tighter types (primitive types) for return types, whereas in the Java version we can only + // use boxed primitive types. + // For example, avg in the Scala veresion returns Scala primitive Double, whose bytecode + // signature is just a java.lang.Object; avg in the Java version returns java.lang.Double. + + // TODO: This is pretty hacky. Maybe we should have an object for implicit encoders. + private val implicits = new SQLImplicits { + override protected def _sqlContext: SQLContext = null + } + + import implicits._ + + /** + * Average aggregate function. + * + * @since 2.0.0 + */ + def avg[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedAverage(f).toColumn + + /** + * Count aggregate function. + * + * @since 2.0.0 + */ + def count[IN](f: IN => Any): TypedColumn[IN, Long] = new TypedCount(f).toColumn + + /** + * Sum aggregate function for floating point (double) type. + * + * @since 2.0.0 + */ + def sum[IN](f: IN => Double): TypedColumn[IN, Double] = new TypedSumDouble[IN](f).toColumn + + /** + * Sum aggregate function for integral (long, i.e. 64 bit integer) type. 
+ * + * @since 2.0.0 + */ + def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new TypedSumLong[IN](f).toColumn + + // TODO: + // stddevOf: Double + // varianceOf: Double + // approxCountDistinct: Long + + // minOf: T + // maxOf: T + + // firstOf: T + // lastOf: T +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 8b355befc3..48925910ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -106,7 +106,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { /** * Creates a [[Column]] for this UDAF using given [[Column]]s as input arguments. */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def apply(exprs: Column*): Column = { val aggregateExpression = AggregateExpression( @@ -120,7 +120,7 @@ abstract class UserDefinedAggregateFunction extends Serializable { * Creates a [[Column]] for this UDAF using the distinct values of the given * [[Column]]s as input arguments. */ - @scala.annotation.varargs + @_root_.scala.annotation.varargs def distinct(exprs: Column*): Column = { val aggregateExpression = AggregateExpression( diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a6c819373b..a5ab446e08 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -37,7 +37,6 @@ import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.*; -import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; @@ -385,59 +384,6 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(data, ds.collectAsList()); } - @Test - public void testTypedAggregation() { - Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); - List> data = - Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); - Dataset> ds = context.createDataset(data, encoder); - - KeyValueGroupedDataset> grouped = ds.groupByKey( - new MapFunction, String>() { - @Override - public String call(Tuple2 value) throws Exception { - return value._1(); - } - }, - Encoders.STRING()); - - Dataset> agged = - grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); - Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); - - Dataset> agged2 = grouped.agg( - new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) - .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); - Assert.assertEquals( - Arrays.asList( - new Tuple2<>("a", 3), - new Tuple2<>("b", 3)), - agged2.collectAsList()); - } - - static class IntSumOf extends Aggregator, Integer, Integer> { - - @Override - public Integer zero() { - return 0; - } - - @Override - public Integer reduce(Integer l, Tuple2 t) { - return l + t._2(); - } - - @Override - public Integer merge(Integer b1, Integer b2) { - return b1 + b2; - } - - @Override - public Integer finish(Integer reduction) { - return reduction; - } - } - public static class KryoSerializable { String value; diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java new file mode 100644 index 0000000000..c4c455b6e6 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import scala.Tuple2; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.sql.test.TestSQLContext; + +/** + * Suite for testing the aggregate functionality of Datasets in Java. 
+ */ +public class JavaDatasetAggregatorSuite implements Serializable { + private transient JavaSparkContext jsc; + private transient TestSQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + private Tuple2 tuple2(T1 t1, T2 t2) { + return new Tuple2<>(t1, t2); + } + + private KeyValueGroupedDataset> generateGroupedDataset() { + Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); + List> data = + Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); + Dataset> ds = context.createDataset(data, encoder); + + return ds.groupByKey( + new MapFunction, String>() { + @Override + public String call(Tuple2 value) throws Exception { + return value._1(); + } + }, + Encoders.STRING()); + } + + @Test + public void testTypedAggregationAnonClass() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + + Dataset> agged = + grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + + Dataset> agged2 = grouped.agg( + new IntSumOf().toColumn(Encoders.INT(), Encoders.INT())) + .as(Encoders.tuple(Encoders.STRING(), Encoders.INT())); + Assert.assertEquals( + Arrays.asList( + new Tuple2<>("a", 3), + new Tuple2<>("b", 3)), + agged2.collectAsList()); + } + + static class IntSumOf extends Aggregator, Integer, Integer> { + + @Override + public Integer zero() { + return 0; + } + + @Override + public Integer reduce(Integer l, Tuple2 t) { + return l + t._2(); + } + + @Override + public Integer merge(Integer b1, Integer b2) { + return b1 + b2; + } + + @Override + public Integer finish(Integer reduction) { + return reduction; + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 84770169f0..5430aff6ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -20,35 +20,10 @@ package org.apache.spark.sql import scala.language.postfixOps import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext -/** An `Aggregator` that adds up any numeric type returned by the given function. 
 */ -class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] { - val numeric = implicitly[Numeric[N]] - - override def zero: N = numeric.zero - - override def reduce(b: N, a: I): N = numeric.plus(b, f(a)) - - override def merge(b1: N, b2: N): N = numeric.plus(b1, b2) - - override def finish(reduction: N): N = reduction -} - -object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] { - override def zero: (Long, Long) = (0, 0) - - override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = { - (countAndSum._1 + 1, countAndSum._2 + input._2) - } - - override def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = { - (b1._1 + b2._1, b1._2 + b2._2) - } - - override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1 -} object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] { @@ -113,15 +88,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ - def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = - new SumOf(f).toColumn - test("typed aggregation: TypedAggregator") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() checkDataset( - ds.groupByKey(_._1).agg(sum(_._2)), - ("a", 30), ("b", 3), ("c", 1)) + ds.groupByKey(_._1).agg(typed.sum(_._2)), + ("a", 30.0), ("b", 3.0), ("c", 1.0)) } test("typed aggregation: TypedAggregator, expr, expr") { @@ -129,20 +101,10 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { checkDataset( ds.groupByKey(_._1).agg( - sum(_._2), + typed.sum(_._2), expr("sum(_2)").as[Long], count("*")), - ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L)) - } - - test("typed aggregation: complex case") { - val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() - - checkDataset( - ds.groupByKey(_._1).agg( - expr("avg(_2)").as[Double], - TypedAverage.toColumn), - ("a", 2.0, 2.0), ("b", 3.0, 3.0)) + ("a", 30.0, 30L, 2L), ("b", 3.0, 3L, 2L), ("c", 1.0, 1L, 1L)) } test("typed aggregation: complex result type") { @@ -159,11 +121,11 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { val ds = Seq(1, 3, 2, 5).toDS() checkDataset( - ds.select(sum((i: Int) => i)), - 11) + ds.select(typed.sum((i: Int) => i)), + 11.0) checkDataset( - ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), - 11 -> 22) + ds.select(typed.sum((i: Int) => i), typed.sum((i: Int) => i * 2)), + 11.0 -> 22.0) } test("typed aggregation: class input") { @@ -206,4 +168,12 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ds.groupByKey(_.b).agg(ComplexBufferAgg.toColumn), ("one", 1), ("two", 1)) } + + test("typed aggregate: avg, count, sum") { + val ds = Seq("a" -> 1, "a" -> 3, "b" -> 3).toDS() + checkDataset( + ds.groupByKey(_._1).agg( + typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)), + ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)) + } } -- cgit v1.2.3 From d7982a3a9aa804e7e3a2004335e7f314867a5f8a Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 1 Apr 2016 22:51:47 -0700 Subject: [MINOR][SQL] Fix comment style and correct several styles and nits in CSV data source ## What changes were proposed in this pull request? While trying to create a PR (which turned out not to be an issue in the end), I corrected some style nits. So, I removed the changes except for some coding style corrections.
- According to the [scala-style-guide#documentation-style](https://github.com/databricks/scala-style-guide#documentation-style), the ScalaDoc comment style is discouraged. >```scala >/** This is a correct one-liner, short description. */ > >/** > * This is correct multi-line JavaDoc comment. And > * this is my second line, and if I keep typing, this would be > * my third line. > */ > >/** In Spark, we don't use the ScalaDoc style so this > * is not correct. > */ >``` - Double newlines between consecutive methods were removed. According to [scala-style-guide#blank-lines-vertical-whitespace](https://github.com/databricks/scala-style-guide#blank-lines-vertical-whitespace), a single newline appears when >Between consecutive members (or initializers) of a class: fields, constructors, methods, nested classes, static initializers, instance initializers. - Remove useless parentheses in tests - Use `mapPartitions` instead of `mapPartitionsWithIndex()`. ## How was this patch tested? Unit tests were used and `dev/run_tests` for style tests. Author: hyukjinkwon Closes #12109 from HyukjinKwon/SPARK-14271. --- .../sql/execution/datasources/csv/CSVParser.scala | 80 +++++++++++----------- .../execution/datasources/csv/CSVRelation.scala | 6 +- .../execution/datasources/csv/DefaultSource.scala | 1 - .../execution/datasources/csv/CSVParserSuite.scala | 10 +-- 4 files changed, 48 insertions(+), 49 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index 7cf1b4c662..5570b2c173 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -25,11 +25,11 @@ import com.univocity.parsers.csv.{CsvParser, CsvParserSettings, CsvWriter, CsvWr import org.apache.spark.internal.Logging /** - * Read and parse CSV-like input - * - * @param params Parameters object - * @param headers headers for the columns - */ + * Read and parse CSV-like input + * + * @param params Parameters object + * @param headers headers for the columns + */ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) { protected lazy val parser: CsvParser = { @@ -54,11 +54,11 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) } /** - * Converts a sequence of string to CSV string - * - * @param params Parameters object for configuration - * @param headers headers for columns - */ + * Converts a sequence of string to CSV string + * + * @param params Parameters object for configuration + * @param headers headers for columns + */ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) extends Logging { private val writerSettings = new CsvWriterSettings private val format = writerSettings.getFormat @@ -90,18 +90,18 @@ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten } /** - * Parser for parsing a line at a time. Not efficient for bulk data. - * - * @param params Parameters object - */ + * Parser for parsing a line at a time. Not efficient for bulk data.
+ * + * @param params Parameters object + */ private[sql] class LineCsvReader(params: CSVOptions) extends CsvReader(params, null) { /** - * parse a line - * - * @param line a String with no newline at the end - * @return array of strings where each string is a field in the CSV record - */ + * parse a line + * + * @param line a String with no newline at the end + * @return array of strings where each string is a field in the CSV record + */ def parseLine(line: String): Array[String] = { parser.beginParsing(new StringReader(line)) val parsed = parser.parseNext() @@ -111,12 +111,12 @@ private[sql] class LineCsvReader(params: CSVOptions) } /** - * Parser for parsing lines in bulk. Use this when efficiency is desired. - * - * @param iter iterator over lines in the file - * @param params Parameters object - * @param headers headers for the columns - */ + * Parser for parsing lines in bulk. Use this when efficiency is desired. + * + * @param iter iterator over lines in the file + * @param params Parameters object + * @param headers headers for the columns + */ private[sql] class BulkCsvReader( iter: Iterator[String], params: CSVOptions, @@ -128,9 +128,9 @@ private[sql] class BulkCsvReader( private var nextRecord = parser.parseNext() /** - * get the next parsed line. - * @return array of strings where each string is a field in the CSV record - */ + * get the next parsed line. + * @return array of strings where each string is a field in the CSV record + */ override def next(): Array[String] = { val curRecord = nextRecord if(curRecord != null) { @@ -146,11 +146,11 @@ private[sql] class BulkCsvReader( } /** - * A Reader that "reads" from a sequence of lines. Spark's textFile method removes newlines at - * end of each line Univocity parser requires a Reader that provides access to the data to be - * parsed and needs the newlines to be present - * @param iter iterator over RDD[String] - */ + * A Reader that "reads" from a sequence of lines. 
Spark's textFile method removes newlines at + * end of each line Univocity parser requires a Reader that provides access to the data to be + * parsed and needs the newlines to be present + * @param iter iterator over RDD[String] + */ private class StringIteratorReader(val iter: Iterator[String]) extends java.io.Reader { private var next: Long = 0 @@ -159,9 +159,9 @@ private class StringIteratorReader(val iter: Iterator[String]) extends java.io.R private var str: String = null // current string from iter /** - * fetch next string from iter, if done with current one - * pretend there is a new line at the end of every string we get from from iter - */ + * fetch next string from iter, if done with current one + * pretend there is a new line at the end of every string we get from from iter + */ private def refill(): Unit = { if (length == next) { if (iter.hasNext) { @@ -175,8 +175,8 @@ private class StringIteratorReader(val iter: Iterator[String]) extends java.io.R } /** - * read the next character, if at end of string pretend there is a new line - */ + * read the next character, if at end of string pretend there is a new line + */ override def read(): Int = { refill() if (next >= length) { @@ -189,8 +189,8 @@ private class StringIteratorReader(val iter: Iterator[String]) extends java.io.R } /** - * read from str into cbuf - */ + * read from str into cbuf + */ override def read(cbuf: Array[Char], off: Int, len: Int): Int = { refill() var n = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index b47328a3dd..54fb03b6d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -42,12 +42,12 @@ object CSVRelation extends Logging { firstLine: String, params: CSVOptions): RDD[Array[String]] = { // If header is set, make sure firstLine is materialized before sending to executors. 
- file.mapPartitionsWithIndex({ - case (split, iter) => new BulkCsvReader( + file.mapPartitions { iter => + new BulkCsvReader( if (params.headerFlag) iter.filterNot(_ == firstLine) else iter, params, headers = header) - }, true) + } } def csvParser( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 6b6add48cd..c0d6f6fbf7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -164,7 +164,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - private def baseRdd( sqlContext: SQLContext, options: CSVOptions, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala index c0c38c6787..dc54883277 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala @@ -46,7 +46,7 @@ class CSVParserSuite extends SparkFunSuite { var numRead = 0 var n = 0 do { // try to fill cbuf - var off = 0 + var off = 0 var len = cbuf.length n = reader.read(cbuf, off, len) @@ -81,7 +81,7 @@ class CSVParserSuite extends SparkFunSuite { test("Regular case") { val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"") val read = readAll(input.toIterator) - assert(read === input.mkString("\n") ++ ("\n")) + assert(read === input.mkString("\n") ++ "\n") } test("Empty iter") { @@ -93,12 +93,12 @@ class CSVParserSuite extends SparkFunSuite { test("Embedded new line") { val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"") val read = readAll(input.toIterator) - assert(read === input.mkString("\n") ++ ("\n")) + assert(read === input.mkString("\n") ++ "\n") } test("Buffer Regular case") { val input = List("This is a string", "This is another string", "Small", "", "\"quoted\"") - val output = input.mkString("\n") ++ ("\n") + val output = input.mkString("\n") ++ "\n" for(i <- 1 to output.length + 5) { val read = readBufAll(input.toIterator, i) assert(read === output) @@ -116,7 +116,7 @@ class CSVParserSuite extends SparkFunSuite { test("Buffer Embedded new line") { val input = List("This is a string", "This is another string", "Small\n", "", "\"quoted\"") - val output = input.mkString("\n") ++ ("\n") + val output = input.mkString("\n") ++ "\n" for(i <- 1 to output.length + 5) { val read = readBufAll(input.toIterator, 1) assert(read === output) -- cgit v1.2.3 From 67d753516da9b6318cd4001bb7ae91703aaf098d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 2 Apr 2016 00:00:19 -0700 Subject: [HOTFIX] Fix compilation break. 
--- .../spark/sql/execution/streaming/HDFSMetadataLogSuite.scala | 1 + .../apache/spark/sql/streaming/StreamingAggregationSuite.scala | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index d5db9db36b..1328142704 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -21,6 +21,7 @@ import java.io.{File, FileNotFoundException, IOException} import java.net.URI import java.util.ConcurrentModificationException +import scala.language.implicitConversions import scala.util.Random import org.apache.hadoop.conf.Configuration diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index b63ce89d18..3af7c01e52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkException -import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn} +import org.apache.spark.sql.StreamTest import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -118,11 +119,8 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext { } test("typed aggregators") { - def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = - new SumOf(f).toColumn - val inputData = MemoryStream[(String, Int)] - val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2)) + val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2)) testStream(aggregated)( AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)), -- cgit v1.2.3 From 06694f1c68cb752ea311144f0dbe50e92e1393cf Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Sat, 2 Apr 2016 08:12:04 -0700 Subject: [MINOR] Typo fixes ## What changes were proposed in this pull request? Typo fixes. No functional changes. ## How was this patch tested? Built the sources and ran with samples. Author: Jacek Laskowski Closes #11802 from jaceklaskowski/typo-fixes. 
--- .../streaming/RecoverableNetworkWordCount.scala | 2 +- .../main/scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../spark/ml/regression/LinearRegression.scala | 2 +- .../sql/catalyst/plans/logical/LogicalPlan.scala | 4 ++-- .../org/apache/spark/sql/ExperimentalMethods.scala | 2 +- .../sql/execution/joins/BroadcastHashJoin.scala | 2 +- .../scala/org/apache/spark/sql/functions.scala | 12 ++++++------ .../apache/spark/streaming/StreamingContext.scala | 13 +++++++------ .../streaming/dstream/ConstantInputDStream.scala | 2 +- .../apache/spark/streaming/dstream/DStream.scala | 8 ++++---- .../streaming/dstream/DStreamCheckpointData.scala | 6 +++--- .../spark/streaming/dstream/InputDStream.scala | 6 +++--- .../streaming/dstream/ReducedWindowedDStream.scala | 2 +- .../spark/streaming/dstream/StateDStream.scala | 12 ++++++------ .../streaming/scheduler/ReceivedBlockTracker.scala | 4 ++-- .../streaming/scheduler/rate/RateEstimator.scala | 22 ++++++++++++---------- 16 files changed, 52 insertions(+), 49 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 05f8e65d65..b6b8bc33f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("You arguments were " + args.mkString("[", ", ", "]")) + System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 3a99979a88..afefaaa883 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -147,7 +147,7 @@ class Pipeline @Since("1.4.0") ( t case _ => throw new IllegalArgumentException( - s"Do not support stage $stage of type ${stage.getClass}") + s"Does not support stage $stage of type ${stage.getClass}") } if (index < indexOfLastEstimator) { curDataset = transformer.transform(curDataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ba5ad4c072..2633c06f40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -58,7 +58,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams * The specific squared error loss function used is: * L = 1/2n ||A coefficients - y||^2^ * - * This support multiple types of regularization: + * This supports multiple types of regularization: * - none (a.k.a. 
ordinary least squares) * - L2 (ridge regression) * - L1 (Lasso) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ecf4285c46..aceeb8aadc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -79,13 +79,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** * Computes [[Statistics]] for this plan. The default implementation assumes the output - * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * cardinality is the product of all child plan's cardinality, i.e. applies in the case * of cartesian joins. * * [[LeafNode]]s must override this. */ def statistics: Statistics = { - if (children.size == 0) { + if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index d7cd84fd24..c5df028485 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -37,7 +37,7 @@ class ExperimentalMethods private[sql]() { /** * Allows extra strategies to be injected into the query planner at runtime. Note this API - * should be consider experimental and is not intended to be stable across releases. + * should be considered experimental and is not intended to be stable across releases. * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index f5b083c216..0ed1ed41b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. When the output RDD of this operator is * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed * relation is not shuffled. */ case class BroadcastHashJoin( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 74906050ac..baf947d037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2232,7 +2232,7 @@ object functions { /** * Splits str around pattern (pattern is a regular expression). - * NOTE: pattern is a string represent the regular expression. + * NOTE: pattern is a string representation of the regular expression. * * @group string_funcs * @since 1.5.0 @@ -2267,9 +2267,9 @@ object functions { /** * Translate any character in the src by a character in replaceString. 
- * The characters in replaceString is corresponding to the characters in matchingString. - * The translate will happen when any character in the string matching with the character - * in the matchingString. + * The characters in replaceString correspond to the characters in matchingString. + * The translate will happen when any character in the string matches the character + * in the `matchingString`. * * @group string_funcs * @since 1.5.0 @@ -2692,7 +2692,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns true if the array contain the value + * Returns true if the array contains `value` * @group collection_funcs * @since 1.5.0 */ @@ -2920,7 +2920,7 @@ object functions { /** * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must - * specifcy the output data type, and there is no automatic input type coercion. + * specify the output data type, and there is no automatic input type coercion. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 3a664c4f5c..c1e151d08b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -132,7 +132,7 @@ class StreamingContext private[streaming] ( "both SparkContext and checkpoint as null") } - private[streaming] val isCheckpointPresent = (_cp != null) + private[streaming] val isCheckpointPresent: Boolean = _cp != null private[streaming] val sc: SparkContext = { if (_sc != null) { @@ -213,8 +213,8 @@ class StreamingContext private[streaming] ( def sparkContext: SparkContext = sc /** - * Set each DStreams in this context to remember RDDs it generated in the last given duration. - * DStreams remember RDDs only for a limited duration of time and releases them for garbage + * Set each DStream in this context to remember RDDs it generated in the last given duration. + * DStreams remember RDDs only for a limited duration of time and release them for garbage * collection. This method allows the developer to specify how long to remember the RDDs ( * if the developer wishes to query old data outside the DStream computation). * @param duration Minimum duration that each DStream should remember its RDDs @@ -282,13 +282,14 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream from TCP source hostname:port. Data is received using + * Creates an input stream from TCP source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded `\n` delimited * lines. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @see [[socketStream]] */ def socketTextStream( hostname: String, @@ -299,7 +300,7 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream from TCP source hostname:port. Data is received using + * Creates an input stream from TCP source hostname:port. Data is received using * a TCP socket and the receive bytes it interpreted as object using the given * converter. 
* @param hostname Hostname to connect to for receiving data @@ -860,7 +861,7 @@ private class StreamingContextPythonHelper { */ def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = { val checkpointOption = CheckpointReader.read( - checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, false) + checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, ignoreReadError = false) checkpointOption.map(new StreamingContext(null, _, null)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index b5f86fe779..995470ec8d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{StreamingContext, Time} /** - * An input stream that always returns the same RDD on each timestep. Useful for testing. + * An input stream that always returns the same RDD on each time step. Useful for testing. */ class ConstantInputDStream[T: ClassTag](_ssc: StreamingContext, rdd: RDD[T]) extends InputDStream[T](_ssc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index eb7b64eaf4..c40beeff97 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -83,7 +83,7 @@ abstract class DStream[T: ClassTag] ( // RDDs generated, marked as private[streaming] so that testsuites can access it @transient - private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () + private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]]() // Time zero for the DStream private[streaming] var zeroTime: Time = null @@ -269,7 +269,7 @@ abstract class DStream[T: ClassTag] ( checkpointDuration == null || rememberDuration > checkpointDuration, s"The remember duration for ${this.getClass.getSimpleName} has been set to " + s" $rememberDuration which is not more than the checkpoint interval" + - s" ($checkpointDuration). Please set it to higher than $checkpointDuration." + s" ($checkpointDuration). Please set it to a value higher than $checkpointDuration." 
) dependencies.foreach(_.validateAtStart()) @@ -277,7 +277,7 @@ abstract class DStream[T: ClassTag] ( logInfo(s"Slide time = $slideDuration") logInfo(s"Storage level = ${storageLevel.description}") logInfo(s"Checkpoint interval = $checkpointDuration") - logInfo(s"Remember duration = $rememberDuration") + logInfo(s"Remember interval = $rememberDuration") logInfo(s"Initialized and validated $this") } @@ -535,7 +535,7 @@ abstract class DStream[T: ClassTag] ( private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(s"${this.getClass().getSimpleName}.readObject used") ois.defaultReadObject() - generatedRDDs = new HashMap[Time, RDD[T]] () + generatedRDDs = new HashMap[Time, RDD[T]]() } // ======================================================================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 365a6bc417..431c9dbe2c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.Time import org.apache.spark.util.Utils private[streaming] -class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) +class DStreamCheckpointData[T: ClassTag](dstream: DStream[T]) extends Serializable with Logging { protected val data = new HashMap[Time, AnyRef]() @@ -45,7 +45,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) /** * Updates the checkpoint data of the DStream. This gets called every time * the graph checkpoint is initiated. Default implementation records the - * checkpoint files to which the generate RDDs of the DStream has been saved. + * checkpoint files at which the generated RDDs of the DStream have been saved. */ def update(time: Time) { @@ -103,7 +103,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) /** * Restore the checkpoint data. This gets called once when the DStream graph - * (along with its DStreams) are being restored from a graph checkpoint file. + * (along with its output DStreams) is being restored from a graph checkpoint file. * Default implementation restores the RDDs from their checkpoint files. 
*/ def restore() { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 0b6b191dbe..dc88349db5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils * * @param _ssc Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (_ssc: StreamingContext) +abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) extends DStream[T](_ssc) { private[streaming] var lastValidTime: Time = null @@ -90,8 +90,8 @@ abstract class InputDStream[T: ClassTag] (_ssc: StreamingContext) } else { // Time is valid, but check it it is more than lastValidTime if (lastValidTime != null && time < lastValidTime) { - logWarning("isTimeValid called with " + time + " where as last valid time is " + - lastValidTime) + logWarning(s"isTimeValid called with $time whereas the last valid time " + + s"is $lastValidTime") } lastValidTime = time true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index a9be2f213f..a9e93838b8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -87,7 +87,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( logDebug("Window time = " + windowDuration) logDebug("Slide time = " + slideDuration) - logDebug("ZeroTime = " + zeroTime) + logDebug("Zero time = " + zeroTime) logDebug("Current window = " + currentWindow) logDebug("Previous window = " + previousWindow) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 68eff89030..0379957e58 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -70,7 +70,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // Try to get the parent RDD parent.getOrCompute(validTime) match { case Some(parentRDD) => { // If parent RDD exists, then compute as usual - computeUsingPreviousRDD (parentRDD, prevStateRDD) + computeUsingPreviousRDD(parentRDD, prevStateRDD) } case None => { // If parent RDD does not exist @@ -98,15 +98,15 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // and then apply the update function val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => { - updateFuncLocal (iterator.map (tuple => (tuple._1, tuple._2.toSeq, None))) + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None))) } - val groupedRDD = parentRDD.groupByKey (partitioner) - val sessionRDD = groupedRDD.mapPartitions (finalFunc, preservePartitioning) + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) // logDebug("Generating state RDD for time " + validTime + " (first)") - Some (sessionRDD) + Some(sessionRDD) } - case Some (initialStateRDD) => { + case Some(initialStateRDD) => { computeUsingPreviousRDD(parentRDD, initialStateRDD) } } diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 9c8e68b03d..5d9a8ac0d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -119,7 +119,7 @@ private[streaming] class ReceivedBlockTracker( timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } else { - logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery") } } else { // This situation occurs when: @@ -129,7 +129,7 @@ private[streaming] class ReceivedBlockTracker( // 2. Slow checkpointing makes recovered batch time older than WAL recovered // lastAllocatedBatchTime. // This situation will only occurs in recovery time. - logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index d7210f64fc..7b2ef6881d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -21,18 +21,20 @@ import org.apache.spark.SparkConf import org.apache.spark.streaming.Duration /** - * A component that estimates the rate at wich an InputDStream should ingest - * elements, based on updates at every batch completion. + * A component that estimates the rate at which an `InputDStream` should ingest + * records, based on updates at every batch completion. + * + * @see [[org.apache.spark.streaming.scheduler.RateController]] */ private[streaming] trait RateEstimator extends Serializable { /** - * Computes the number of elements the stream attached to this `RateEstimator` + * Computes the number of records the stream attached to this `RateEstimator` * should ingest per second, given an update on the size and completion * times of the latest batch. * - * @param time The timetamp of the current batch interval that just finished - * @param elements The number of elements that were processed in this batch + * @param time The timestamp of the current batch interval that just finished + * @param elements The number of records that were processed in this batch * @param processingDelay The time in ms that took for the job to complete * @param schedulingDelay The time in ms that the job spent in the scheduling queue */ @@ -46,13 +48,13 @@ private[streaming] trait RateEstimator extends Serializable { object RateEstimator { /** - * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. + * Return a new `RateEstimator` based on the value of + * `spark.streaming.backpressure.rateEstimator`. * - * The only known estimator right now is `pid`. + * The only known and acceptable estimator right now is `pid`. * * @return An instance of RateEstimator - * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any - * known estimators. + * @throws IllegalArgumentException if the configured RateEstimator is not `pid`. 
*/ def create(conf: SparkConf, batchInterval: Duration): RateEstimator = conf.get("spark.streaming.backpressure.rateEstimator", "pid") match { @@ -64,6 +66,6 @@ object RateEstimator { new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate) case estimator => - throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") + throw new IllegalArgumentException(s"Unknown rate estimator: $estimator") } } -- cgit v1.2.3 From a3e293542a6e7df9bcc7d9bbd22b3c93a81bcc38 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 2 Apr 2016 12:44:02 -0700 Subject: [HOTFIX] Disable StateStoreSuite.maintenance --- .../apache/spark/sql/execution/streaming/state/StateStoreSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 0e5936d53f..dd23925716 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -352,7 +352,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } } - test("maintenance") { + ignore("maintenance") { val conf = new SparkConf() .setMaster("local") .setAppName("test") -- cgit v1.2.3 From f705037617d55bb479ec60bcb1e55c736224be94 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 2 Apr 2016 17:48:53 -0700 Subject: [SPARK-14338][SQL] Improve `SimplifyConditionals` rule to handle `null` in IF/CASEWHEN ## What changes were proposed in this pull request? Currently, `SimplifyConditionals` handles `true` and `false` to optimize branches. This PR improves `SimplifyConditionals` to take advantage of `null` conditions for `if` and `CaseWhen` expressions, too. **Before** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [if (null) 1 else 0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [CASE WHEN null THEN 1 ELSE 2 END AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#14] : +- INPUT +- Scan OneRowRelation[] ``` **After** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [2 AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#4] : +- INPUT +- Scan OneRowRelation[] ``` **Hive** ``` hive> select if(null,1,2); OK 2 hive> select case when cast(null as boolean) then 1 else 2 end; OK 2 ``` ## How was this patch tested? Pass the Jenkins tests (including new extended test cases). Author: Dongjoon Hyun Closes #12122 from dongjoon-hyun/SPARK-14338. 
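The behavior this rule exploits can be summarized in a small, self-contained Scala sketch (an illustration only, not the Catalyst implementation; the `Cond` type and function names below are invented for the example): a literal `null` condition in `IF`/`CASE WHEN` can never evaluate to true, so such branches are pruned exactly like branches guarded by a literal `false`.

```scala
// Sketch of the branch-pruning idea: a branch whose condition is a literal
// null behaves like one whose condition is literal false, so it can be dropped.
sealed trait Cond
case object TrueLit extends Cond
case object FalseLit extends Cond
case object NullLit extends Cond                   // e.g. CAST(NULL AS BOOLEAN)
case class NonFoldable(name: String) extends Cond  // unknown until runtime

def simplifyCaseWhen[A](branches: Seq[(Cond, A)], elseValue: A): Either[Seq[(Cond, A)], A] = {
  // Branches guarded by false or null can never fire, so drop them.
  val reachable = branches.filterNot(b => b._1 == FalseLit || b._1 == NullLit)
  reachable.headOption match {
    case Some((TrueLit, value)) => Right(value)      // first remaining branch always fires
    case None                   => Right(elseValue)  // only the else branch is left
    case _                      => Left(reachable)   // keep the pruned CASE WHEN
  }
}
```

For instance, `simplifyCaseWhen(Seq((NullLit, 1)), 2)` folds to the else value `2`, mirroring the Hive result quoted above for `SELECT CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END`.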
--- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 13 ++++++++++--- .../catalyst/optimizer/SimplifyConditionalSuite.scala | 16 +++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 326933ec9e..a5ab390c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { - def nonNullLiteral(e: Expression): Boolean = e match { + private def nonNullLiteral(e: Expression): Boolean = e match { case Literal(null, _) => false case _ => true } @@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { * Simplifies conditional expressions (if / case). */ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { + private def falseOrNullLiteral(e: Expression): Boolean = e match { + case FalseLiteral => true + case Literal(null, _) => true + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue + case If(Literal(null, _), _, falseValue) => falseValue - case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) => + case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. // If there are no more branches left, just use the else value. // Note that these two are handled together here in a single case statement because // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null). 
- val newBranches = branches.filter(_._1 != FalseLiteral) + val newBranches = branches.filter(x => !falseOrNullLiteral(x._1)) if (newBranches.isEmpty) { elseValue.getOrElse(Literal.create(null, e.dataType)) } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index d436b627f6..33239c0084 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, NullType} class SimplifyConditionalSuite extends PlanTest with PredicateHelper { @@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val trueBranch = (TrueLiteral, Literal(5)) private val normalBranch = (NonFoldableLiteral(true), Literal(10)) private val unreachableBranch = (FalseLiteral, Literal(20)) + private val nullBranch = (Literal(null, NullType), Literal(30)) test("simplify if") { assertEquivalent( @@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { assertEquivalent( If(FalseLiteral, Literal(10), Literal(20)), Literal(20)) + + assertEquivalent( + If(Literal(null, NullType), Literal(10), Literal(20)), + Literal(20)) } test("remove unreachable branches") { // i.e. removing branches whose conditions are always false assertEquivalent( - CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None), + CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None), CaseWhen(normalBranch :: Nil, None)) } test("remove entire CaseWhen if only the else branch is reachable") { assertEquivalent( - CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))), + CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))), Literal(30)) assertEquivalent( @@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { test("remove entire CaseWhen if the first branch is always true") { assertEquivalent( - CaseWhen(trueBranch :: normalBranch :: Nil, None), + CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None), Literal(5)) // Test branch elimination and simplification in combination assertEquivalent( - CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None), + CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch + :: Nil, None), Literal(5)) // Make sure this doesn't trigger if there is a non-foldable branch before the true branch -- cgit v1.2.3 From 4a6e78abd9d5edc4a5092738dff0006bbe202a89 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 2 Apr 2016 17:50:40 -0700 Subject: [MINOR][DOCS] Use multi-line JavaDoc comments in Scala code. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR aims to fix all Scala-Style multiline comments into Java-Style multiline comments in Scala codes. 
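Concretely, the conversion moves the comment text off the opening line and aligns continuation asterisks under the first `*`; a representative before/after (with invented comment text, for illustration only) looks like this:

```scala
object CommentStyleExample {
  /** Scala-style (before): the description starts on the opening line
    * and continuation lines are indented under the second asterisk.
    */
  def before(): Unit = ()

  /**
   * Java-style (after): the opening line is left bare and every continuation
   * line starts with an asterisk aligned under the first one.
   */
  def after(): Unit = ()
}
```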
(All comment-only changes over 77 files: +786 lines, −747 lines) ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #12130 from dongjoon-hyun/use_multiine_javadoc_comments. --- .../main/scala/org/apache/spark/FutureAction.scala | 14 +- .../main/scala/org/apache/spark/SSLOptions.scala | 57 +++--- .../main/scala/org/apache/spark/SparkContext.scala | 42 +++-- .../org/apache/spark/api/java/JavaPairRDD.scala | 8 +- .../apache/spark/api/java/JavaSparkContext.scala | 60 ++++--- .../apache/spark/deploy/worker/CommandUtils.scala | 2 +- .../scala/org/apache/spark/rdd/CoGroupedRDD.scala | 10 +- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 +- .../mesos/CoarseMesosSchedulerBackend.scala | 24 +-- .../cluster/mesos/MesosClusterScheduler.scala | 12 +- .../cluster/mesos/MesosSchedulerUtils.scala | 4 +- .../org/apache/spark/shuffle/ShuffleManager.scala | 6 +- .../apache/spark/storage/memory/MemoryStore.scala | 20 +-- .../main/scala/org/apache/spark/util/Utils.scala | 17 +- core/src/test/scala/org/apache/spark/Smuggle.scala | 46 ++--- .../apache/spark/memory/MemoryManagerSuite.scala | 24 +-- .../org/apache/spark/examples/BroadcastTest.scala | 4 +- .../apache/spark/examples/DFSReadWriteTest.scala | 20 +-- .../org/apache/spark/examples/GroupByTest.scala | 4 +- .../apache/spark/examples/MultiBroadcastTest.scala | 4 +- .../spark/examples/SimpleSkewedGroupByTest.scala | 4 +- .../apache/spark/examples/SkewedGroupByTest.scala | 4 +- .../streaming/clickstream/PageViewGenerator.scala | 23 +-- .../streaming/clickstream/PageViewStream.scala | 21 +-- .../spark/streaming/flume/FlumeInputDStream.scala | 15 +- .../spark/streaming/kafka/KafkaRDDPartition.scala | 15 +- .../scala/org/apache/spark/graphx/GraphOps.scala | 10 +- .../spark/graphx/lib/ConnectedComponents.scala | 18 +- .../spark/ml/feature/ElementwiseProduct.scala | 6 +- .../api/python/GaussianMixtureModelWrapper.scala | 8 +- .../mllib/api/python/Word2VecModelWrapper.scala | 4 +- .../org/apache/spark/mllib/linalg/Matrices.scala | 16 +- .../StreamingLinearRegressionWithSGD.scala | 4 +- .../scala/org/apache/spark/repl/SparkILoop.scala | 21 ++- .../scala/org/apache/spark/repl/SparkImports.scala | 5 +- .../main/scala/org/apache/spark/sql/Encoder.scala | 24 +-- .../spark/sql/catalyst/analysis/Analyzer.scala | 20 +-- .../sql/catalyst/expressions/Projection.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 26 +-- .../spark/sql/catalyst/expressions/grouping.scala | 18 +- .../spark/sql/catalyst/expressions/misc.scala | 4 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 40 ++--- .../spark/sql/catalyst/parser/AstBuilder.scala | 4 +- .../spark/sql/catalyst/planning/patterns.scala | 28 +-- .../spark/sql/catalyst/plans/QueryPlan.scala | 4 +- .../sql/catalyst/plans/physical/partitioning.scala | 6 +- .../optimizer/OptimizerExtendableSuite.scala | 14 +- .../scala/org/apache/spark/sql/SQLContext.scala | 14 +- .../apache/spark/sql/execution/CacheManager.scala | 7 +- .../org/apache/spark/sql/execution/SparkPlan.scala | 4 +- .../spark/sql/execution/WholeStageCodegen.scala | 172 +++++++++---------- .../org/apache/spark/sql/execution/Window.scala | 36 ++-- .../execution/aggregate/AggregationIterator.scala | 22 +-- .../aggregate/SortBasedAggregationIterator.scala | 6 +- .../execution/aggregate/TungstenAggregate.scala | 16 +- .../execution/datasources/SqlNewHadoopRDD.scala | 8 +- .../execution/datasources/csv/CSVInferSchema.scala | 22 +-- .../execution/datasources/csv/DefaultSource.scala | 4 +- .../spark/sql/execution/datasources/ddl.scala | 8 +- 
.../sql/execution/joins/BroadcastHashJoin.scala | 22 +-- .../sql/execution/joins/CartesianProduct.scala | 8 +- .../spark/sql/execution/joins/HashJoin.scala | 8 +- .../spark/sql/execution/joins/HashedRelation.scala | 76 ++++---- .../spark/sql/execution/joins/SortMergeJoin.scala | 36 ++-- .../state/HDFSBackedStateStoreProvider.scala | 8 +- .../spark/sql/execution/ui/SparkPlanGraph.scala | 4 +- .../scala/org/apache/spark/sql/functions.scala | 191 ++++++++++----------- .../org/apache/spark/sql/sources/interfaces.scala | 26 +-- .../scala/org/apache/spark/sql/QueryTest.scala | 4 +- .../sql/execution/BenchmarkWholeStageCodegen.scala | 8 +- .../execution/datasources/csv/CSVParserSuite.scala | 4 +- .../apache/spark/streaming/StreamingContext.scala | 7 +- .../streaming/api/java/JavaStreamingContext.scala | 7 +- .../spark/streaming/receiver/RateLimiter.scala | 23 +-- .../apache/spark/streaming/scheduler/JobSet.scala | 7 +- .../apache/spark/tools/GenerateMIMAIgnore.scala | 9 +- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 16 +- 77 files changed, 786 insertions(+), 747 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 2a8220ff40..ce11772a6d 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -146,16 +146,16 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: /** - * Handle via which a "run" function passed to a [[ComplexFutureAction]] - * can submit jobs for execution. - */ + * Handle via which a "run" function passed to a [[ComplexFutureAction]] + * can submit jobs for execution. + */ @DeveloperApi trait JobSubmitter { /** - * Submit a job for execution and return a FutureAction holding the result. - * This is a wrapper around the same functionality provided by SparkContext - * to enable cancellation. - */ + * Submit a job for execution and return a FutureAction holding the result. + * This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. + */ def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 30db6ccbf4..719905a2c9 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -132,34 +132,35 @@ private[spark] case class SSLOptions( private[spark] object SSLOptions extends Logging { - /** Resolves SSLOptions settings from a given Spark configuration object at a given namespace. 
- * - * The following settings are allowed: - * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively - * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory - * $ - `[ns].keyStorePassword` - a password to the key-store file - * $ - `[ns].keyPassword` - a password to the private key - * $ - `[ns].keyStoreType` - the type of the key-store - * $ - `[ns].needClientAuth` - whether SSL needs client authentication - * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current - * directory - * $ - `[ns].trustStorePassword` - a password to the trust-store file - * $ - `[ns].trustStoreType` - the type of trust-store - * $ - `[ns].protocol` - a protocol name supported by a particular Java version - * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers - * - * For a list of protocols and ciphers supported by particular Java versions, you may go to - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. - * - * You can optionally specify the default configuration. If you do, for each setting which is - * missing in SparkConf, the corresponding setting is used from the default configuration. - * - * @param conf Spark configuration object where the settings are collected from - * @param ns the namespace name - * @param defaults the default configuration - * @return [[org.apache.spark.SSLOptions]] object - */ + /** + * Resolves SSLOptions settings from a given Spark configuration object at a given namespace. + * + * The following settings are allowed: + * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory + * $ - `[ns].keyStorePassword` - a password to the key-store file + * $ - `[ns].keyPassword` - a password to the private key + * $ - `[ns].keyStoreType` - the type of the key-store + * $ - `[ns].needClientAuth` - whether SSL needs client authentication + * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current + * directory + * $ - `[ns].trustStorePassword` - a password to the trust-store file + * $ - `[ns].trustStoreType` - the type of trust-store + * $ - `[ns].protocol` - a protocol name supported by a particular Java version + * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers + * + * For a list of protocols and ciphers supported by particular Java versions, you may go to + * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle + * blog page]]. + * + * You can optionally specify the default configuration. If you do, for each setting which is + * missing in SparkConf, the corresponding setting is used from the default configuration. 
+ * + * @param conf Spark configuration object where the settings are collected from + * @param ns the namespace name + * @param defaults the default configuration + * @return [[org.apache.spark.SSLOptions]] object + */ def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d7cb253d69..4b3264cbf5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -773,9 +773,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli parallelize(seq, numSlices) } - /** Distribute a local Scala collection to form an RDD, with one or more - * location preferences (hostnames of Spark nodes) for each object. - * Create a new partition for each collection item. */ + /** + * Distribute a local Scala collection to form an RDD, with one or more + * location preferences (hostnames of Spark nodes) for each object. + * Create a new partition for each collection item. + */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap @@ -1095,14 +1097,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new NewHadoopRDD(this, fClass, kClass, vClass, jconf) } - /** Get an RDD for a Hadoop SequenceFile with given key and value types. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle - * operation will create many references to the same object. - * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first - * copy them using a `map` function. - */ + /** + * Get an RDD for a Hadoop SequenceFile with given key and value types. + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. + */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V], @@ -1113,14 +1116,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) } - /** Get an RDD for a Hadoop SequenceFile with given key and value types. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle - * operation will create many references to the same object. - * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first - * copy them using a `map` function. - * */ + /** + * Get an RDD for a Hadoop SequenceFile with given key and value types. 
+ * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. + */ def sequenceFile[K, V]( path: String, keyClass: Class[K], diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index e080f91f50..2897272a8b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -461,10 +461,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(rdd.partitionBy(partitioner)) /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. - */ + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. + */ def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] = fromRDD(rdd.join(other, partitioner)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index d362c40b7a..dfd91ae338 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -295,13 +295,14 @@ class JavaSparkContext(val sc: SparkContext) new JavaRDD(sc.binaryRecords(path, recordLength)) } - /** Get an RDD for a Hadoop SequenceFile with given key and value types. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - * */ + /** + * Get an RDD for a Hadoop SequenceFile with given key and value types. + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD will create many references to the same object. + * If you plan to directly cache Hadoop writable objects, you should first copy them using + * a `map` function. + */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V], @@ -312,13 +313,14 @@ class JavaSparkContext(val sc: SparkContext) new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minPartitions)) } - /** Get an RDD for a Hadoop SequenceFile. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - */ + /** + * Get an RDD for a Hadoop SequenceFile. 
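// A short sketch of the copy-before-cache pattern that the note above recommends, assuming
// an existing SparkContext `sc`; the path and the Text/IntWritable key and value types are
// illustrative assumptions.
import org.apache.hadoop.io.{IntWritable, Text}

val counts = sc.sequenceFile("hdfs:///data/word-counts", classOf[Text], classOf[IntWritable])
  // Materialize immutable copies first: Hadoop's RecordReader re-uses the same Writable
  // instances, so caching the raw records would cache many references to a single object.
  .map { case (word, count) => (word.toString, count.get) }
  .cache()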
+ * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD will create many references to the same object. + * If you plan to directly cache Hadoop writable objects, you should first copy them using + * a `map` function. + */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) @@ -411,13 +413,14 @@ class JavaSparkContext(val sc: SparkContext) new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } - /** Get an RDD for a Hadoop file with an arbitrary InputFormat. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - */ + /** + * Get an RDD for a Hadoop file with an arbitrary InputFormat. + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD will create many references to the same object. + * If you plan to directly cache Hadoop writable objects, you should first copy them using + * a `map` function. + */ def hadoopFile[K, V, F <: InputFormat[K, V]]( path: String, inputFormatClass: Class[F], @@ -431,13 +434,14 @@ class JavaSparkContext(val sc: SparkContext) new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } - /** Get an RDD for a Hadoop file with an arbitrary InputFormat - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - */ + /** + * Get an RDD for a Hadoop file with an arbitrary InputFormat + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD will create many references to the same object. + * If you plan to directly cache Hadoop writable objects, you should first copy them using + * a `map` function. + */ def hadoopFile[K, V, F <: InputFormat[K, V]]( path: String, inputFormatClass: Class[F], diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index a4efafcb27..cba4aaffe2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -29,7 +29,7 @@ import org.apache.spark.launcher.WorkerCommandBuilder import org.apache.spark.util.Utils /** - ** Utilities for running commands with the spark classpath. + * Utilities for running commands with the spark classpath. 
*/ private[deploy] object CommandUtils extends Logging { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index e5ebc63082..7bc1eb0436 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -29,10 +29,12 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils -/** The references to rdd and splitIndex are transient because redundant information is stored - * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from - * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the - * task closure. */ +/** + * The references to rdd and splitIndex are transient because redundant information is stored + * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from + * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the + * task closure. + */ private[spark] case class NarrowCoGroupSplitDep( @transient rdd: RDD[_], @transient splitIndex: Int, diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index f96551c793..4a0a2199ef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -255,8 +255,8 @@ abstract class RDD[T: ClassTag]( } /** - * Returns the number of partitions of this RDD. - */ + * Returns the number of partitions of this RDD. + */ @Since("1.6.0") final def getNumPartitions: Int = partitions.length diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 90b1813750..50b452c72f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -295,12 +295,12 @@ private[spark] class CoarseMesosSchedulerBackend( } /** - * Launches executors on accepted offers, and declines unused offers. Executors are launched - * round-robin on offers. - * - * @param d SchedulerDriver - * @param offers Mesos offers that match attribute constraints - */ + * Launches executors on accepted offers, and declines unused offers. Executors are launched + * round-robin on offers. + * + * @param d SchedulerDriver + * @param offers Mesos offers that match attribute constraints + */ private def handleMatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { val tasks = buildMesosTasks(offers) for (offer <- offers) { @@ -336,12 +336,12 @@ private[spark] class CoarseMesosSchedulerBackend( } /** - * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize - * per-task memory and IO, tasks are round-robin assigned to offers. - * - * @param offers Mesos offers that match attribute constraints - * @return A map from OfferID to a list of Mesos tasks to launch on that offer - */ + * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize + * per-task memory and IO, tasks are round-robin assigned to offers. 
+ * + * @param offers Mesos offers that match attribute constraints + * @return A map from OfferID to a list of Mesos tasks to launch on that offer + */ private def buildMesosTasks(offers: Buffer[Offer]): Map[OfferID, List[MesosTaskInfo]] = { // offerID -> tasks val tasks = new HashMap[OfferID, List[MesosTaskInfo]].withDefaultValue(Nil) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index c41fa58607..73bd4c58e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -453,12 +453,12 @@ private[spark] class MesosClusterScheduler( } /** - * Escape args for Unix-like shells, unless already quoted by the user. - * Based on: http://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html - * and http://www.grymoire.com/Unix/Quote.html - * @param value argument - * @return escaped argument - */ + * Escape args for Unix-like shells, unless already quoted by the user. + * Based on: http://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html + * and http://www.grymoire.com/Unix/Quote.html + * @param value argument + * @return escaped argument + */ private[scheduler] def shellEscape(value: String): String = { val WrappedInQuotes = """^(".+"|'.+')$""".r val ShellSpecialChars = (""".*([ '<>&|\?\*;!#\\(\)"$`]).*""").r diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 9a12a61f2f..35f914355d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -148,8 +148,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } /** - * Signal that the scheduler has registered with Mesos. - */ + * Signal that the scheduler has registered with Mesos. + */ protected def markRegistered(): Unit = { registerLatch.countDown() } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 76fd249fbd..364fad664e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -54,9 +54,9 @@ private[spark] trait ShuffleManager { context: TaskContext): ShuffleReader[K, C] /** - * Remove a shuffle's metadata from the ShuffleManager. - * @return true if the metadata removed successfully, otherwise false. - */ + * Remove a shuffle's metadata from the ShuffleManager. + * @return true if the metadata removed successfully, otherwise false. + */ def unregisterShuffle(shuffleId: Int): Boolean /** diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index df38d11e43..99be4de065 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -455,16 +455,16 @@ private[spark] class MemoryStore( } /** - * Try to evict blocks to free up a given amount of space to store a particular block. 
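// A self-contained sketch of the shellEscape behaviour documented a few hunks above, built
// from the two regexes quoted there. The replacement rule for embedded ", `, $ and \
// characters is an assumption for illustration, not something taken from this patch.
def shellEscapeSketch(value: String): String = {
  val WrappedInQuotes = """^(".+"|'.+')$""".r
  val ShellSpecialChars = (""".*([ '<>&|\?\*;!#\\(\)"$`]).*""").r
  value match {
    case WrappedInQuotes(_) => value            // the user already quoted the argument
    case ShellSpecialChars(_) =>
      "\"" + value.replaceAll("""(["`\$\\])""", """\\$1""") + "\""
    case _ => value                             // harmless string, nothing to escape
  }
}
// e.g. shellEscapeSketch("--name=my app") yields "\"--name=my app\"",
// while shellEscapeSketch("'already quoted'") is returned unchanged.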
- * Can fail if either the block is bigger than our memory or it would require replacing - * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for - * RDDs that don't fit into memory that we want to avoid). - * - * @param blockId the ID of the block we are freeing space for, if any - * @param space the size of this block - * @param memoryMode the type of memory to free (on- or off-heap) - * @return the amount of memory (in bytes) freed by eviction - */ + * Try to evict blocks to free up a given amount of space to store a particular block. + * Can fail if either the block is bigger than our memory or it would require replacing + * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for + * RDDs that don't fit into memory that we want to avoid). + * + * @param blockId the ID of the block we are freeing space for, if any + * @param space the size of this block + * @param memoryMode the type of memory to free (on- or off-heap) + * @return the amount of memory (in bytes) freed by eviction + */ private[spark] def evictBlocksToFreeSpace( blockId: Option[BlockId], space: Long, diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 73768ff4c8..50bcf85805 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -256,10 +256,11 @@ private[spark] object Utils extends Logging { dir } - /** Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream - * copying is disabled by default unless explicitly set transferToEnabled as true, - * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. - */ + /** + * Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream + * copying is disabled by default unless explicitly set transferToEnabled as true, + * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. + */ def copyStream(in: InputStream, out: OutputStream, closeStreams: Boolean = false, @@ -1564,9 +1565,11 @@ private[spark] object Utils extends Logging { else -1 } - /** Returns the system properties map that is thread-safe to iterator over. It gets the - * properties which have been set explicitly, as well as those for which only a default value - * has been defined. */ + /** + * Returns the system properties map that is thread-safe to iterator over. It gets the + * properties which have been set explicitly, as well as those for which only a default value + * has been defined. + */ def getSystemProperties: Map[String, String] = { System.getProperties.stringPropertyNames().asScala .map(key => (key, System.getProperty(key))).toMap diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala index 9f0a1b4c25..9d9217ea1b 100644 --- a/core/src/test/scala/org/apache/spark/Smuggle.scala +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -24,16 +24,16 @@ import scala.collection.mutable import scala.language.implicitConversions /** - * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. - * This is intended for testing purposes, primarily to make locks, semaphores, and - * other constructs that would not survive serialization available from within tasks. 
- * A Smuggle reference is itself serializable, but after being serialized and - * deserialized, it still refers to the same underlying "smuggled" object, as long - * as it was deserialized within the same JVM. This can be useful for tests that - * depend on the timing of task completion to be deterministic, since one can "smuggle" - * a lock or semaphore into the task, and then the task can block until the test gives - * the go-ahead to proceed via the lock. - */ + * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. + * This is intended for testing purposes, primarily to make locks, semaphores, and + * other constructs that would not survive serialization available from within tasks. + * A Smuggle reference is itself serializable, but after being serialized and + * deserialized, it still refers to the same underlying "smuggled" object, as long + * as it was deserialized within the same JVM. This can be useful for tests that + * depend on the timing of task completion to be deterministic, since one can "smuggle" + * a lock or semaphore into the task, and then the task can block until the test gives + * the go-ahead to proceed via the lock. + */ class Smuggle[T] private(val key: Symbol) extends Serializable { def smuggledObject: T = Smuggle.get(key) } @@ -41,13 +41,13 @@ class Smuggle[T] private(val key: Symbol) extends Serializable { object Smuggle { /** - * Wraps the specified object to be smuggled into a serialized task without - * being serialized itself. - * - * @param smuggledObject - * @tparam T - * @return Smuggle wrapper around smuggledObject. - */ + * Wraps the specified object to be smuggled into a serialized task without + * being serialized itself. + * + * @param smuggledObject + * @tparam T + * @return Smuggle wrapper around smuggledObject. + */ def apply[T](smuggledObject: T): Smuggle[T] = { val key = Symbol(UUID.randomUUID().toString) lock.writeLock().lock() @@ -72,12 +72,12 @@ object Smuggle { } /** - * Implicit conversion of a Smuggle wrapper to the object being smuggled. - * - * @param smuggle the wrapper to unpack. - * @tparam T - * @return the smuggled object represented by the wrapper. - */ + * Implicit conversion of a Smuggle wrapper to the object being smuggled. + * + * @param smuggle the wrapper to unpack. + * @tparam T + * @return the smuggled object represented by the wrapper. + */ implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 3d1a0e9795..99d5b496bc 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -78,18 +78,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft } /** - * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. - * - * This is a significant simplification of the real method, which actually drops existing - * blocks based on the size of each block. Instead, here we simply release as many bytes - * as needed to ensure the requested amount of free space. This allows us to set up the - * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in - * many other dependencies. - * - * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that - * records the number of bytes this is called with. 
This variable is expected to be cleared - * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. - */ + * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. + * + * This is a significant simplification of the real method, which actually drops existing + * blocks based on the size of each block. Instead, here we simply release as many bytes + * as needed to ensure the requested amount of free space. This allows us to set up the + * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in + * many other dependencies. + * + * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that + * records the number of bytes this is called with. This variable is expected to be cleared + * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. + */ private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] = { new Answer[Long] { override def answer(invocation: InvocationOnMock): Long = { diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 3da5236745..af5a815f6e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -21,8 +21,8 @@ package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: BroadcastTest [slices] [numElem] [blockSize] - */ + * Usage: BroadcastTest [slices] [numElem] [blockSize] + */ object BroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 743fc13db7..7bf023667d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -25,16 +25,16 @@ import scala.io.Source._ import org.apache.spark.{SparkConf, SparkContext} /** - * Simple test for reading and writing to a distributed - * file system. This example does the following: - * - * 1. Reads local file - * 2. Computes word count on local file - * 3. Writes local file to a DFS - * 4. Reads the file back from the DFS - * 5. Computes word count on the file using Spark - * 6. Compares the word count results - */ + * Simple test for reading and writing to a distributed + * file system. This example does the following: + * + * 1. Reads local file + * 2. Computes word count on local file + * 3. Writes local file to a DFS + * 4. Reads the file back from the DFS + * 5. Computes word count on the file using Spark + * 6. 
Compares the word count results + */ object DFSReadWriteTest { private var localFilePath: File = new File(".") diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 08b6c717d4..4db229b5de 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -23,8 +23,8 @@ import java.util.Random import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] - */ + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object GroupByTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("GroupBy Test") diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 134c3d1d63..3eb0c27723 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -22,8 +22,8 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.rdd.RDD /** - * Usage: MultiBroadcastTest [slices] [numElem] - */ + * Usage: MultiBroadcastTest [slices] [numElem] + */ object MultiBroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 7c09664c2f..ec07e6323e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -23,8 +23,8 @@ import java.util.Random import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] - */ + * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] + */ object SimpleSkewedGroupByTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index d498af9c39..8e4c2b6229 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -23,8 +23,8 @@ import java.util.Random import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] - */ + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object SkewedGroupByTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("GroupBy Test") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 50216b9bd4..0ddd065f0d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -38,17 +38,18 @@ object PageView extends Serializable { } // scalastyle:off -/** Generates streaming events to simulate page views on a website. 
- * - * This should be used in tandem with PageViewStream.scala. Example: - * - * To run the generator - * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` - * To process the generated stream - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` - * - */ +/** + * Generates streaming events to simulate page views on a website. + * + * This should be used in tandem with PageViewStream.scala. Example: + * + * To run the generator + * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` + * To process the generated stream + * `$ bin/run-example \ + * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` + * + */ // scalastyle:on object PageViewGenerator { val pages = Map("http://foo.com/" -> .7, diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 773a2e5fc2..1ba093f57b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -22,16 +22,17 @@ import org.apache.spark.examples.streaming.StreamingExamples import org.apache.spark.streaming.{Seconds, StreamingContext} // scalastyle:off -/** Analyses a streaming dataset of web page views. This class demonstrates several types of - * operators available in Spark streaming. - * - * This should be used in tandem with PageViewStream.scala. Example: - * To run the generator - * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` - * To process the generated stream - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` - */ +/** + * Analyses a streaming dataset of web page views. This class demonstrates several types of + * operators available in Spark streaming. + * + * This should be used in tandem with PageViewStream.scala. Example: + * To run the generator + * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` + * To process the generated stream + * `$ bin/run-example \ + * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` + */ // scalastyle:on object PageViewStream { def main(args: Array[String]) { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 7dc9606913..6e7c3f358e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -185,13 +185,14 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from - * and the Netty client and compress data going back to the client. 
- * - * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel - */ + /** + * A Netty Pipeline factory that will decompress incoming data from + * and the Netty client and compress data going back to the client. + * + * The compression on the return is required because Flume requires + * a successful response to indicate it can remove the event/batch + * from the configured channel + */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { def getPipeline(): ChannelPipeline = { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala index a660d2a00c..02917becf0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -19,13 +19,14 @@ package org.apache.spark.streaming.kafka import org.apache.spark.Partition -/** @param topic kafka topic name - * @param partition kafka partition id - * @param fromOffset inclusive starting offset - * @param untilOffset exclusive ending offset - * @param host preferred kafka host, i.e. the leader at the time the rdd was created - * @param port preferred kafka host's port - */ +/** + * @param topic kafka topic name + * @param partition kafka partition id + * @param fromOffset inclusive starting offset + * @param untilOffset exclusive ending offset + * @param host preferred kafka host, i.e. the leader at the time the rdd was created + * @param port preferred kafka host's port + */ private[kafka] class KafkaRDDPartition( val index: Int, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index a783fe305f..868658dfe5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -415,11 +415,11 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali } /** - * Compute the connected component membership of each vertex and return a graph with the vertex - * value containing the lowest vertex id in the connected component containing that vertex. - * - * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]] - */ + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]] + */ def connectedComponents(): Graph[VertexId, ED] = { ConnectedComponents.run(graph) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index 137c512c99..4e9b13162e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -60,15 +60,15 @@ object ConnectedComponents { } // end of connectedComponents /** - * Compute the connected component membership of each vertex and return a graph with the vertex - * value containing the lowest vertex id in the connected component containing that vertex. 
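// A small sketch of the connected-components call described above, assuming an existing
// SparkContext `sc`; the tiny edge list is an illustrative assumption.
import org.apache.spark.graphx.{Edge, Graph}

val edges = sc.parallelize(Seq(Edge(1L, 2L, "a"), Edge(2L, 3L, "b"), Edge(10L, 11L, "c")))
val graph = Graph.fromEdges(edges, defaultValue = 0)

// Each vertex is labelled with the smallest vertex id in its component:
// vertices 1, 2, 3 map to 1 and vertices 10, 11 map to 10.
graph.connectedComponents().vertices.collect().foreach(println)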
- * - * @tparam VD the vertex attribute type (discarded in the computation) - * @tparam ED the edge attribute type (preserved in the computation) - * @param graph the graph for which to compute the connected components - * @return a graph with vertex attributes containing the smallest vertex in each - * connected component - */ + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @tparam VD the vertex attribute type (discarded in the computation) + * @tparam ED the edge attribute type (preserved in the computation) + * @param graph the graph for which to compute the connected components + * @return a graph with vertex attributes containing the smallest vertex in each + * connected component + */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { run(graph, Int.MaxValue) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 2c7ffdb7ba..1b0a9a12e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -38,9 +38,9 @@ class ElementwiseProduct(override val uid: String) def this() = this(Identifiable.randomUID("elemProd")) /** - * the vector to multiply with input vectors - * @group param - */ + * the vector to multiply with input vectors + * @group param + */ val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index a689b09341..364d5eea08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -24,15 +24,15 @@ import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * Wrapper around GaussianMixtureModel to provide helper methods in Python - */ + * Wrapper around GaussianMixtureModel to provide helper methods in Python + */ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { val weights: Vector = Vectors.dense(model.weights) val k: Int = weights.size /** - * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian - */ + * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian + */ val gaussians: Array[Byte] = { val modelGaussians = model.gaussians.map { gaussian => Array[Any](gaussian.mu, gaussian.sigma) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 073f03e16f..05273c3434 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -27,8 +27,8 @@ import org.apache.spark.mllib.feature.Word2VecModel import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * Wrapper around Word2VecModel to provide helper methods in Python - */ + * Wrapper around Word2VecModel to provide helper 
methods in Python + */ private[python] class Word2VecModelWrapper(model: Word2VecModel) { def transform(word: String): Vector = { model.transform(word) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 6e571fe35a..8c09b69b3c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -123,14 +123,18 @@ sealed trait Matrix extends Serializable { @Since("1.4.0") def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) - /** Map the values of this matrix using a function. Generates a new matrix. Performs the - * function on only the backing array. For example, an operation such as addition or - * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ + /** + * Map the values of this matrix using a function. Generates a new matrix. Performs the + * function on only the backing array. For example, an operation such as addition or + * subtraction will only be performed on the non-zero values in a `SparseMatrix`. + */ private[spark] def map(f: Double => Double): Matrix - /** Update all the values of this matrix using the function f. Performed in-place on the - * backing array. For example, an operation such as addition or subtraction will only be - * performed on the non-zero values in a `SparseMatrix`. */ + /** + * Update all the values of this matrix using the function f. Performed in-place on the + * backing array. For example, an operation such as addition or subtraction will only be + * performed on the non-zero values in a `SparseMatrix`. + */ private[mllib] def update(f: Double => Double): Matrix /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index e8f4422fd4..84764963b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -81,8 +81,8 @@ class StreamingLinearRegressionWithSGD private[mllib] ( } /** - * Set the number of iterations of gradient descent to run per update. Default: 50. - */ + * Set the number of iterations of gradient descent to run per update. Default: 50. + */ @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 67a616dc15..c5dc6ba221 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -797,9 +797,11 @@ class SparkILoop( // echo("Switched " + (if (old) "off" else "on") + " result printing.") } - /** Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. */ + /** + * Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. 
+ */ private[repl] def command(line: String): Result = { if (line startsWith ":") { val cmd = line.tail takeWhile (x => !x.isWhitespace) @@ -841,12 +843,13 @@ class SparkILoop( } import paste.{ ContinueString, PromptString } - /** Interpret expressions starting with the first line. - * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ + /** + * Interpret expressions starting with the first line. + * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ private def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala index 1d0fe10d3d..f22776592c 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -118,8 +118,9 @@ private[repl] trait SparkImports { case class ReqAndHandler(req: Request, handler: MemberHandler) { } def reqsToUse: List[ReqAndHandler] = { - /** Loop through a list of MemberHandlers and select which ones to keep. - * 'wanted' is the set of names that need to be imported. + /** + * Loop through a list of MemberHandlers and select which ones to keep. + * 'wanted' is the set of names that need to be imported. */ def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = { // Single symbol imports might be implicits! See bug #1752. Rather than diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1f20e26354..e0bfe3c32f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -140,27 +140,27 @@ object Encoders { def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** - * An encoder for nullable decimal type. - * @since 1.6.0 - */ + * An encoder for nullable decimal type. + * @since 1.6.0 + */ def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() /** - * An encoder for nullable date type. - * @since 1.6.0 - */ + * An encoder for nullable date type. + * @since 1.6.0 + */ def DATE: Encoder[java.sql.Date] = ExpressionEncoder() /** - * An encoder for nullable timestamp type. - * @since 1.6.0 - */ + * An encoder for nullable timestamp type. + * @since 1.6.0 + */ def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() /** - * An encoder for arrays of bytes. - * @since 1.6.1 - */ + * An encoder for arrays of bytes. 
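// A hedged sketch of passing one of these encoders explicitly, assuming an existing
// SQLContext `sqlContext`; the sample date is an illustrative assumption.
import org.apache.spark.sql.{Dataset, Encoders}

val dates: Dataset[java.sql.Date] =
  sqlContext.createDataset(Seq(java.sql.Date.valueOf("2016-03-26")))(Encoders.DATE)
dates.show()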
+ * @since 1.6.1 + */ def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 05e2b9a447..a6e317ebf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -733,9 +733,9 @@ class Analyzer( } /** - * Add the missing attributes into projectList of Project/Window or aggregateExpressions of - * Aggregate. - */ + * Add the missing attributes into projectList of Project/Window or aggregateExpressions of + * Aggregate. + */ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { if (missingAttrs.isEmpty) { return plan @@ -767,9 +767,9 @@ class Analyzer( } /** - * Resolve the expression on a specified logical plan and it's child (recursively), until - * the expression is resolved or meet a non-unary node or Subquery. - */ + * Resolve the expression on a specified logical plan and it's child (recursively), until + * the expression is resolved or meet a non-unary node or Subquery. + */ @tailrec private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { val resolved = resolveExpression(expr, plan) @@ -1398,8 +1398,8 @@ class Analyzer( } /** - * Check and add order to [[AggregateWindowFunction]]s. - */ + * Check and add order to [[AggregateWindowFunction]]s. + */ object ResolveWindowOrder extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case logical: LogicalPlan => logical transformExpressions { @@ -1489,8 +1489,8 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { } /** - * Removes [[Union]] operators from the plan if it just has one child. - */ + * Removes [[Union]] operators from the plan if it just has one child. + */ object EliminateUnions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Union(children) if children.size == 1 => children.head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 053e612f3e..354311c5e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -136,9 +136,9 @@ object UnsafeProjection { } /** - * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * TODO: refactor the plumbing and clean this up. - */ + * Same as other create()'s but allowing enabling/disabling subexpression elimination. + * TODO: refactor the plumbing and clean this up. 
+ */ def create( exprs: Seq[Expression], inputSchema: Seq[Attribute], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cd490dd676..b64d3eea49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -58,10 +58,10 @@ class CodegenContext { val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() /** - * Add an object to `references`, create a class member to access it. - * - * Returns the name of class member. - */ + * Add an object to `references`, create a class member to access it. + * + * Returns the name of class member. + */ def addReferenceObj(name: String, obj: Any, className: String = null): String = { val term = freshName(name) val idx = references.length @@ -72,9 +72,9 @@ class CodegenContext { } /** - * Holding a list of generated columns as input of current operator, will be used by - * BoundReference to generate code. - */ + * Holding a list of generated columns as input of current operator, will be used by + * BoundReference to generate code. + */ var currentVars: Seq[ExprCode] = null /** @@ -169,14 +169,14 @@ class CodegenContext { final var INPUT_ROW = "i" /** - * The map from a variable name to it's next ID. - */ + * The map from a variable name to it's next ID. + */ private val freshNameIds = new mutable.HashMap[String, Int] freshNameIds += INPUT_ROW -> 1 /** - * A prefix used to generate fresh name. - */ + * A prefix used to generate fresh name. + */ var freshNamePrefix = "" /** @@ -234,8 +234,8 @@ class CodegenContext { } /** - * Update a column in MutableRow from ExprCode. - */ + * Update a column in MutableRow from ExprCode. + */ def updateColumn( row: String, dataType: DataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 437e417266..3be761c867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** - * A placeholder expression for cube/rollup, which will be replaced by analyzer - */ + * A placeholder expression for cube/rollup, which will be replaced by analyzer + */ trait GroupingSet extends Expression with CodegenFallback { def groupByExprs: Seq[Expression] @@ -43,9 +43,9 @@ case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {} case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {} /** - * Indicates whether a specified column expression in a GROUP BY list is aggregated or not. - * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set. - */ + * Indicates whether a specified column expression in a GROUP BY list is aggregated or not. + * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set. 
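// A heavily hedged sketch of what GROUPING exposes once it is wired into SQL; the table,
// the columns, and the exact SQL function spelling are assumptions here, not taken from
// this patch. GROUPING(col) is 1 on rows where `col` has been rolled up (aggregated away)
// and 0 where it holds an actual grouping value.
sqlContext.sql(
  """SELECT course, year, GROUPING(course), GROUPING(year), SUM(earnings)
    |FROM courseSales
    |GROUP BY course, year WITH ROLLUP""".stripMargin).show()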
+ */ case class Grouping(child: Expression) extends Expression with Unevaluable { override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) override def children: Seq[Expression] = child :: Nil @@ -54,10 +54,10 @@ case class Grouping(child: Expression) extends Expression with Unevaluable { } /** - * GroupingID is a function that computes the level of grouping. - * - * If groupByExprs is empty, it means all grouping expressions in GroupingSets. - */ + * GroupingID is a function that computes the level of grouping. + * + * If groupByExprs is empty, it means all grouping expressions in GroupingSets. + */ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable { override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) override def children: Seq[Expression] = groupByExprs diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e8a3e129b4..eb8dc1423a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -467,8 +467,8 @@ object Murmur3HashFunction extends InterpretedHashFunction { } /** - * Print the result of an expression to stderr (used for debugging codegen). - */ + * Print the result of an expression to stderr (used for debugging codegen). + */ case class PrintToStderr(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a5ab390c76..69b09bcb35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ /** - * Abstract class all optimizers should inherit of, contains the standard batches (extending - * Optimizers can override this. - */ + * Abstract class all optimizers should inherit of, contains the standard batches (extending + * Optimizers can override this. + */ abstract class Optimizer extends RuleExecutor[LogicalPlan] { def batches: Seq[Batch] = { // Technically some of the rules in Finish Analysis are not optimizer rules and belong more @@ -111,11 +111,11 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { } /** - * Non-abstract representation of the standard Spark optimizing strategies - * - * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while - * specific rules go to the subclasses - */ + * Non-abstract representation of the standard Spark optimizing strategies + * + * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while + * specific rules go to the subclasses + */ object DefaultOptimizer extends Optimizer /** @@ -962,21 +962,21 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } /** - * Reorder the joins and push all the conditions into join, so that the bottom ones have at least - * one condition. - * - * The order of joins will not be changed if all of them already have at least one condition. 
- */ + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + */ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { /** - * Join a list of plans together and push down the conditions into them. - * - * The joined plan are picked from left to right, prefer those has at least one join condition. - * - * @param input a list of LogicalPlans to join. - * @param conditions a list of condition for join. - */ + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ @tailrec def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { assert(input.size >= 2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c350f3049f..8541b1f7c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1430,8 +1430,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a [[StructType]] from a sequence of [[StructField]]s. - */ + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ protected def createStructType(ctx: ColTypeListContext): StructType = { StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 28d2c445b1..6f35d87ebb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -140,20 +140,20 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } /** - * A pattern that collects the filter and inner joins. - * - * Filter - * | - * inner Join - * / \ ----> (Seq(plan0, plan1, plan2), conditions) - * Filter plan2 - * | - * inner join - * / \ - * plan0 plan1 - * - * Note: This pattern currently only works for left-deep trees. - */ + * A pattern that collects the filter and inner joins. + * + * Filter + * | + * inner Join + * / \ ----> (Seq(plan0, plan1, plan2), conditions) + * Filter plan2 + * | + * inner join + * / \ + * plan0 plan1 + * + * Note: This pattern currently only works for left-deep trees. + */ object ExtractFiltersAndInnerJoins extends PredicateHelper { // flatten all inner joins, which are next to each other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 22a4461e66..609a33e2f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -122,8 +122,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) /** - * The set of all attributes that are produced by this node. 
- */ + * The set of all attributes that are produced by this node. + */ def producedAttributes: AttributeSet = AttributeSet.empty /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index be9f1ffa22..d449088498 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -76,9 +76,9 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { } /** - * Represents data where tuples are broadcasted to every node. It is quite common that the - * entire set of tuples is transformed into different data structure. - */ + * Represents data where tuples are broadcasted to every node. It is quite common that the + * entire set of tuples is transformed into different data structure. + */ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala index 7e3da6bea7..6e5672ddc3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -23,21 +23,21 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule /** - * This is a test for SPARK-7727 if the Optimizer is kept being extendable - */ + * This is a test for SPARK-7727 if the Optimizer is kept being extendable + */ class OptimizerExtendableSuite extends SparkFunSuite { /** - * Dummy rule for test batches - */ + * Dummy rule for test batches + */ object DummyRule extends Rule[LogicalPlan] { def apply(p: LogicalPlan): LogicalPlan = p } /** - * This class represents a dummy extended optimizer that takes the batches of the - * Optimizer and adds custom ones. - */ + * This class represents a dummy extended optimizer that takes the batches of the + * Optimizer and adds custom ones. + */ class ExtendedOptimizer extends Optimizer { // rules set to DummyRule, would not be executed anyways diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 221782ee8f..d4290fee0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -712,13 +712,13 @@ class SQLContext private[sql]( } /** - * :: Experimental :: - * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with an step value. - * - * @since 2.0.0 - * @group dataset - */ + * :: Experimental :: + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with an step value. 
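// A quick sketch of the range() method documented above, assuming an existing SQLContext
// `sqlContext`: it produces a Dataset with a single LongType column named `id`.
val ids = sqlContext.range(0L, 10L, 2L)   // ids 0, 2, 4, 6, 8
ids.collect().foreach(println)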
+ * + * @since 2.0.0 + * @group dataset + */ @Experimental def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, numPartitions = sparkContext.defaultParallelism) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f3478a873a..124ec09efd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -109,9 +109,10 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[Dataset]] from the cache - * if it's cached - */ + /** + * Tries to remove the data for the given [[Dataset]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b1b3d4ac81..ff19d1be1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -84,8 +84,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty /** - * Reset all the metrics. - */ + * Reset all the metrics. + */ private[sql] def resetMetrics(): Unit = { metrics.valuesIterator.foreach(_.reset()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 9bdf611f6e..9f539c4929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.internal.SQLConf /** - * An interface for those physical operators that support codegen. - */ + * An interface for those physical operators that support codegen. + */ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ @@ -46,10 +46,10 @@ trait CodegenSupport extends SparkPlan { } /** - * Creates a metric using the specified name. - * - * @return name of the variable representing the metric - */ + * Creates a metric using the specified name. + * + * @return name of the variable representing the metric + */ def metricTerm(ctx: CodegenContext, name: String): String = { val metric = ctx.addReferenceObj(name, longMetric(name)) val value = ctx.freshName("metricValue") @@ -59,25 +59,25 @@ trait CodegenSupport extends SparkPlan { } /** - * Whether this SparkPlan support whole stage codegen or not. - */ + * Whether this SparkPlan support whole stage codegen or not. + */ def supportCodegen: Boolean = true /** - * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan. - */ + * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan. + */ protected var parent: CodegenSupport = null /** - * Returns all the RDDs of InternalRow which generates the input rows. - * - * Note: right now we support up to two RDDs. - */ + * Returns all the RDDs of InternalRow which generates the input rows. 
+ * + * Note: right now we support up to two RDDs. + */ def upstreams(): Seq[RDD[InternalRow]] /** - * Returns Java source code to process the rows from upstream. - */ + * Returns Java source code to process the rows from upstream. + */ final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent ctx.freshNamePrefix = variablePrefix @@ -89,28 +89,28 @@ trait CodegenSupport extends SparkPlan { } /** - * Generate the Java source code to process, should be overridden by subclass to support codegen. - * - * doProduce() usually generate the framework, for example, aggregation could generate this: - * - * if (!initialized) { - * # create a hash map, then build the aggregation hash map - * # call child.produce() - * initialized = true; - * } - * while (hashmap.hasNext()) { - * row = hashmap.next(); - * # build the aggregation results - * # create variables for results - * # call consume(), which will call parent.doConsume() + * Generate the Java source code to process, should be overridden by subclass to support codegen. + * + * doProduce() usually generate the framework, for example, aggregation could generate this: + * + * if (!initialized) { + * # create a hash map, then build the aggregation hash map + * # call child.produce() + * initialized = true; + * } + * while (hashmap.hasNext()) { + * row = hashmap.next(); + * # build the aggregation results + * # create variables for results + * # call consume(), which will call parent.doConsume() * if (shouldStop()) return; - * } - */ + * } + */ protected def doProduce(ctx: CodegenContext): String /** - * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume(). - */ + * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume(). + */ final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = { val inputVars = if (row != null) { @@ -158,9 +158,9 @@ trait CodegenSupport extends SparkPlan { } /** - * Returns source code to evaluate all the variables, and clear the code of them, to prevent - * them to be evaluated twice. - */ + * Returns source code to evaluate all the variables, and clear the code of them, to prevent + * them to be evaluated twice. + */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") variables.foreach(_.code = "") @@ -168,9 +168,9 @@ trait CodegenSupport extends SparkPlan { } /** - * Returns source code to evaluate the variables for required attributes, and clear the code - * of evaluated variables, to prevent them to be evaluated twice.. - */ + * Returns source code to evaluate the variables for required attributes, and clear the code + * of evaluated variables, to prevent them to be evaluated twice.. + */ protected def evaluateRequiredVariables( attributes: Seq[Attribute], variables: Seq[ExprCode], @@ -194,18 +194,18 @@ trait CodegenSupport extends SparkPlan { def usedInputs: AttributeSet = references /** - * Generate the Java source code to process the rows from child SparkPlan. - * - * This should be override by subclass to support codegen. - * - * For example, Filter will generate the code like this: - * - * # code to evaluate the predicate expression, result is isNull1 and value2 - * if (isNull1 || !value2) continue; - * # call consume(), which will call parent.doConsume() - * - * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). 
- */ + * Generate the Java source code to process the rows from child SparkPlan. + * + * This should be override by subclass to support codegen. + * + * For example, Filter will generate the code like this: + * + * # code to evaluate the predicate expression, result is isNull1 and value2 + * if (isNull1 || !value2) continue; + * # call consume(), which will call parent.doConsume() + * + * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). + */ def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException } @@ -213,11 +213,11 @@ trait CodegenSupport extends SparkPlan { /** - * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. - * - * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes - * an RDD iterator of InternalRow. - */ + * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. + * + * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes + * an RDD iterator of InternalRow. + */ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -260,33 +260,33 @@ object WholeStageCodegen { } /** - * WholeStageCodegen compile a subtree of plans that support codegen together into single Java - * function. - * - * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not): - * - * WholeStageCodegen Plan A FakeInput Plan B - * ========================================================================= - * - * -> execute() - * | - * doExecute() ---------> upstreams() -------> upstreams() ------> execute() - * | - * +-----------------> produce() - * | - * doProduce() -------> produce() - * | - * doProduce() - * | - * doConsume() <--------- consume() - * | - * doConsume() <-------- consume() - * - * SparkPlan A should override doProduce() and doConsume(). - * - * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, - * used to generated code for BoundReference. - */ + * WholeStageCodegen compile a subtree of plans that support codegen together into single Java + * function. + * + * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not): + * + * WholeStageCodegen Plan A FakeInput Plan B + * ========================================================================= + * + * -> execute() + * | + * doExecute() ---------> upstreams() -------> upstreams() ------> execute() + * | + * +-----------------> produce() + * | + * doProduce() -------> produce() + * | + * doProduce() + * | + * doConsume() <--------- consume() + * | + * doConsume() <-------- consume() + * + * SparkPlan A should override doProduce() and doConsume(). + * + * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, + * used to generated code for BoundReference. + */ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -422,8 +422,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup /** - * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. - */ + * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. 
+ */ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 7d0567842c..806089196c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -444,8 +444,8 @@ private[execution] final case class RangeBoundOrdering( } /** - * The interface of row buffer for a partition - */ + * The interface of row buffer for a partition + */ private[execution] abstract class RowBuffer { /** Number of rows. */ @@ -462,8 +462,8 @@ private[execution] abstract class RowBuffer { } /** - * A row buffer based on ArrayBuffer (the number of rows is limited) - */ + * A row buffer based on ArrayBuffer (the number of rows is limited) + */ private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { private[this] var cursor: Int = -1 @@ -493,8 +493,8 @@ private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends } /** - * An external buffer of rows based on UnsafeExternalSorter - */ + * An external buffer of rows based on UnsafeExternalSorter + */ private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) extends RowBuffer { @@ -654,12 +654,16 @@ private[execution] final class SlidingWindowFunctionFrame( /** The rows within current sliding window. */ private[this] val buffer = new util.ArrayDeque[InternalRow]() - /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ private[this] var inputHighIndex = 0 - /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ private[this] var inputLowIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ @@ -763,8 +767,10 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( /** The next row from `input`. */ private[this] var nextRow: InternalRow = null - /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ @@ -819,8 +825,10 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( /** Rows of the partition currently being processed. */ private[this] var input: RowBuffer = null - /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 15627a7004..042c731901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -47,17 +47,17 @@ abstract class AggregationIterator( /////////////////////////////////////////////////////////////////////////// /** - * The following combinations of AggregationMode are supported: - * - Partial - * - PartialMerge (for single distinct) - * - Partial and PartialMerge (for single distinct) - * - Final - * - Complete (for SortBasedAggregate with functions that does not support Partial) - * - Final and Complete (currently not used) - * - * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression - * could have a flag to tell it's final or not. - */ + * The following combinations of AggregationMode are supported: + * - Partial + * - PartialMerge (for single distinct) + * - Partial and PartialMerge (for single distinct) + * - Final + * - Complete (for SortBasedAggregate with functions that does not support Partial) + * - Final and Complete (currently not used) + * + * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression + * could have a flag to tell it's final or not. + */ { val modes = aggregateExpressions.map(_.mode).distinct.toSet require(modes.size <= 2, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 8f974980bb..de1491d357 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -46,9 +46,9 @@ class SortBasedAggregationIterator( newMutableProjection) { /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. - */ + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ private def newBuffer: MutableRow = { val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 7c215d1b96..60027edc7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -266,8 +266,8 @@ case class TungstenAggregate( private var sorterTerm: String = _ /** - * This is called by generated Java class, should be public. - */ + * This is called by generated Java class, should be public. + */ def createHashMap(): UnsafeFixedWidthAggregationMap = { // create initialized aggregate buffer val initExpr = declFunctions.flatMap(f => f.initialValues) @@ -286,15 +286,15 @@ case class TungstenAggregate( } /** - * This is called by generated Java class, should be public. - */ + * This is called by generated Java class, should be public. 
+ */ def createUnsafeJoiner(): UnsafeRowJoiner = { GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) } /** - * Called by generated Java class to finish the aggregate and return a KVIterator. - */ + * Called by generated Java class to finish the aggregate and return a KVIterator. + */ def finishAggregate( hashMap: UnsafeFixedWidthAggregationMap, sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { @@ -372,8 +372,8 @@ case class TungstenAggregate( } /** - * Generate the code for output. - */ + * Generate the code for output. + */ private def generateResultCode( ctx: CodegenContext, keyTerm: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index f3514cd14c..159fdc99dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -168,10 +168,10 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( private[this] var reader: RecordReader[Void, V] = null /** - * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this - * fails (for example, unsupported schema), try with the normal reader. - * TODO: plumb this through a different way? - */ + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? + */ if (enableVectorizedParquetReader && format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { val parquetReader: VectorizedParquetRecordReader = new VectorizedParquetRecordReader() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 797f740dc5..ea843a1013 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -33,11 +33,11 @@ import org.apache.spark.unsafe.types.UTF8String private[csv] object CSVInferSchema { /** - * Similar to the JSON schema inference - * 1. Infer type of each row - * 2. Merge row types to find common type - * 3. Replace any null types with string type - */ + * Similar to the JSON schema inference + * 1. Infer type of each row + * 2. Merge row types to find common type + * 3. Replace any null types with string type + */ def infer( tokenRdd: RDD[Array[String]], header: Array[String], @@ -75,9 +75,9 @@ private[csv] object CSVInferSchema { } /** - * Infer type of string field. Given known type Double, and a string "1", there is no - * point checking if it is an Int, as the final type must be Double or higher. - */ + * Infer type of string field. Given known type Double, and a string "1", there is no + * point checking if it is an Int, as the final type must be Double or higher. 
+ */ def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { if (field == null || field.isEmpty || field == nullValue) { typeSoFar @@ -142,9 +142,9 @@ private[csv] object CSVInferSchema { private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence /** - * Copied from internal Spark api - * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] - */ + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index c0d6f6fbf7..34fcbdf871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -38,8 +38,8 @@ import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.collection.BitSet /** - * Provides access to CSV data from pure SQL statements. - */ + * Provides access to CSV data from pure SQL statements. + */ class DefaultSource extends FileFormat with DataSourceRegister { override def shortName(): String = "csv" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 877e159fbd..2e88d588be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -51,11 +51,11 @@ case class DescribeCommand( } /** - * Used to represent the operation of create table using a data source. + * Used to represent the operation of create table using a data source. * - * @param allowExisting If it is true, we will do nothing when the table already exists. - * If it is false, an exception will be thrown - */ + * @param allowExisting If it is true, we will do nothing when the table already exists. + * If it is false, an exception will be thrown + */ case class CreateTableUsing( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 0ed1ed41b0..41e566c27b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -122,8 +122,8 @@ case class BroadcastHashJoin( } /** - * Returns a tuple of Broadcast of HashedRelation and the variable name for it. - */ + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. + */ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() @@ -139,9 +139,9 @@ case class BroadcastHashJoin( } /** - * Returns the code for generating join key for stream side, and expression of whether the key - * has any null in it or not. - */ + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. 
+ */ private def genStreamSideJoinKey( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { @@ -160,8 +160,8 @@ case class BroadcastHashJoin( } /** - * Generates the code for variable of build side. - */ + * Generates the code for variable of build side. + */ private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { ctx.currentVars = null ctx.INPUT_ROW = matched @@ -188,8 +188,8 @@ case class BroadcastHashJoin( } /** - * Generates the code for Inner join. - */ + * Generates the code for Inner join. + */ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) @@ -254,8 +254,8 @@ case class BroadcastHashJoin( /** - * Generates the code for left or right outer join. - */ + * Generates the code for left or right outer join. + */ private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index fb65b50da8..edb4c5a16f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -28,10 +28,10 @@ import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** - * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, - * will be much faster than building the right partition for every row in left RDD, it also - * materialize the right RDD (in case of the right RDD is nondeterministic). - */ + * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, + * will be much faster than building the right partition for every row in left RDD, it also + * materialize the right RDD (in case of the right RDD is nondeterministic). + */ private[spark] class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5f42d07273..c298b7dee0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -64,10 +64,10 @@ trait HashJoin { } /** - * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. - * - * If not, returns the original expressions. - */ + * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. + * + * If not, returns the original expressions. 
+ */ def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { var keyExpr: Expression = null var width = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index dc4793e85a..91c470d187 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -38,20 +38,20 @@ import org.apache.spark.util.collection.CompactBuffer */ private[execution] sealed trait HashedRelation { /** - * Returns matched rows. - */ + * Returns matched rows. + */ def get(key: InternalRow): Seq[InternalRow] /** - * Returns matched rows for a key that has only one column with LongType. - */ + * Returns matched rows for a key that has only one column with LongType. + */ def get(key: Long): Seq[InternalRow] = { throw new UnsupportedOperationException } /** - * Returns the size of used memory. - */ + * Returns the size of used memory. + */ def getMemorySize: Long = 1L // to make the test happy /** @@ -77,20 +77,20 @@ private[execution] sealed trait HashedRelation { } /** - * Interface for a hashed relation that have only one row per key. - * - * We should call getValue() for better performance. - */ + * Interface for a hashed relation that have only one row per key. + * + * We should call getValue() for better performance. + */ private[execution] trait UniqueHashedRelation extends HashedRelation { /** - * Returns the matched single row. - */ + * Returns the matched single row. + */ def getValue(key: InternalRow): InternalRow /** - * Returns the matched single row with key that have only one column of LongType. - */ + * Returns the matched single row with key that have only one column of LongType. + */ def getValue(key: Long): InternalRow = { throw new UnsupportedOperationException } @@ -345,8 +345,8 @@ private[joins] object UnsafeHashedRelation { } /** - * An interface for a hashed relation that the key is a Long. - */ + * An interface for a hashed relation that the key is a Long. + */ private[joins] trait LongHashedRelation extends HashedRelation { override def get(key: InternalRow): Seq[InternalRow] = { get(key.getLong(0)) @@ -396,26 +396,26 @@ private[joins] final class UniqueLongHashedRelation( } /** - * A relation that pack all the rows into a byte array, together with offsets and sizes. - * - * All the bytes of UnsafeRow are packed together as `bytes`: - * - * [ Row0 ][ Row1 ][] ... [ RowN ] - * - * With keys: - * - * start start+1 ... start+N - * - * `offsets` are offsets of UnsafeRows in the `bytes` - * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. - * - * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: - * - * start = 3 - * offsets = [0, 0, 24] - * sizes = [24, 0, 32] - * bytes = [0 - 24][][24 - 56] - */ + * A relation that pack all the rows into a byte array, together with offsets and sizes. + * + * All the bytes of UnsafeRow are packed together as `bytes`: + * + * [ Row0 ][ Row1 ][] ... [ RowN ] + * + * With keys: + * + * start start+1 ... start+N + * + * `offsets` are offsets of UnsafeRows in the `bytes` + * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. 
+ * + * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: + * + * start = 3 + * offsets = [0, 0, 24] + * sizes = [24, 0, 32] + * bytes = [0 - 24][][24 - 56] + */ private[joins] final class LongArrayRelation( private var numFields: Int, private var start: Long, @@ -483,8 +483,8 @@ private[joins] final class LongArrayRelation( } /** - * Create hashed relation with key that is long. - */ + * Create hashed relation with key that is long. + */ private[joins] object LongHashedRelation { val DENSE_FACTOR = 0.2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 60bd8ea39a..0e7b2f2f31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -256,9 +256,9 @@ case class SortMergeJoin( } /** - * Generate a function to scan both left and right to find a match, returns the term for - * matched one row from left side and buffered rows from right side. - */ + * Generate a function to scan both left and right to find a match, returns the term for + * matched one row from left side and buffered rows from right side. + */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. val leftRow = ctx.freshName("leftRow") @@ -341,12 +341,12 @@ case class SortMergeJoin( } /** - * Creates variables for left part of result row. - * - * In order to defer the access after condition and also only access once in the loop, - * the variables should be declared separately from accessing the columns, we can't use the - * codegen of BoundReference here. - */ + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. + */ private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => @@ -370,9 +370,9 @@ case class SortMergeJoin( } /** - * Creates the variables for right part of result row, using BoundReference, since the right - * part are accessed inside the loop. - */ + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. + */ private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { ctx.INPUT_ROW = rightRow right.output.zipWithIndex.map { case (a, i) => @@ -381,12 +381,12 @@ case class SortMergeJoin( } /** - * Splits variables based on whether it's used by condition or not, returns the code to create - * these variables before the condition and after the condition. - * - * Only a few columns are used by condition, then we can skip the accessing of those columns - * that are not used by condition also filtered out by condition. - */ + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. 
+ */ private def splitVarsByCondition( attributes: Seq[Attribute], variables: Seq[ExprCode]): (String, String) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 998eb82de1..8ece3c971a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -468,10 +468,10 @@ private[state] class HDFSBackedStateStoreProvider( } /** - * Clean up old snapshots and delta files that are not needed any more. It ensures that last - * few versions of the store can be recovered from the files, so re-executed RDD operations - * can re-apply updates on the past versions of the store. - */ + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. + */ private[state] def cleanup(): Unit = { try { val files = fetchFiles() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 24a01f5be1..012b125d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -45,8 +45,8 @@ private[ui] case class SparkPlanGraph( } /** - * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen. - */ + * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen. + */ val allNodes: Seq[SparkPlanGraphNode] = { nodes.flatMap { case cluster: SparkPlanGraphCluster => cluster.nodes :+ cluster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index baf947d037..da58ba2add 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -332,95 +332,94 @@ object functions { } /** - * Aggregate function: returns the first value in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { new First(e.expr, Literal(ignoreNulls)) } /** - * Aggregate function: returns the first value of a column in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. 
It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ def first(columnName: String, ignoreNulls: Boolean): Column = { first(Column(columnName), ignoreNulls) } /** - * Aggregate function: returns the first value in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def first(e: Column): Column = first(e, ignoreNulls = false) /** - * Aggregate function: returns the first value of a column in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def first(columnName: String): Column = first(Column(columnName)) - /** - * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated - * or not, returns 1 for aggregated or 0 for not aggregated in the result set. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ def grouping(e: Column): Column = Column(Grouping(e.expr)) /** - * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated - * or not, returns 1 for aggregated or 0 for not aggregated in the result set. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ def grouping(columnName: String): Column = grouping(Column(columnName)) /** - * Aggregate function: returns the level of grouping, equals to - * - * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) - * - * Note: the list of columns should match with grouping columns exactly, or empty (means all the - * grouping columns). - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the level of grouping, equals to + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns should match with grouping columns exactly, or empty (means all the + * grouping columns). + * + * @group agg_funcs + * @since 2.0.0 + */ def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr))) /** - * Aggregate function: returns the level of grouping, equals to - * - * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... 
+ grouping(cn) - * - * Note: the list of columns should match with grouping columns exactly. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the level of grouping, equals to + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns should match with grouping columns exactly. + * + * @group agg_funcs + * @since 2.0.0 + */ def grouping_id(colName: String, colNames: String*): Column = { grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*) } @@ -442,51 +441,51 @@ object functions { def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) /** - * Aggregate function: returns the last value in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { new Last(e.expr, Literal(ignoreNulls)) } /** - * Aggregate function: returns the last value of the column in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ def last(columnName: String, ignoreNulls: Boolean): Column = { last(Column(columnName), ignoreNulls) } /** - * Aggregate function: returns the last value in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def last(e: Column): Column = last(e, ignoreNulls = false) /** - * Aggregate function: returns the last value of the column in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. 
+ * + * @group agg_funcs + * @since 1.3.0 + */ def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index e8834d052c..14e14710f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -152,19 +152,19 @@ trait StreamSinkProvider { @DeveloperApi trait CreatableRelationProvider { /** - * Creates a relation with the given parameters based on the contents of the given - * DataFrame. The mode specifies the expected behavior of createRelation when - * data already exists. - * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. - * Append mode means that when saving a DataFrame to a data source, if data already exists, - * contents of the DataFrame are expected to be appended to existing data. - * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, - * existing data is expected to be overwritten by the contents of the DataFrame. - * ErrorIfExists mode means that when saving a DataFrame to a data source, - * if data already exists, an exception is expected to be thrown. - * - * @since 1.3.0 - */ + * Creates a relation with the given parameters based on the contents of the given + * DataFrame. The mode specifies the expected behavior of createRelation when + * data already exists. + * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. + * Append mode means that when saving a DataFrame to a data source, if data already exists, + * contents of the DataFrame are expected to be appended to existing data. + * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, + * existing data is expected to be overwritten by the contents of the DataFrame. + * ErrorIfExists mode means that when saving a DataFrame to a data source, + * if data already exists, an exception is expected to be thrown. + * + * @since 1.3.0 + */ def createRelation( sqlContext: SQLContext, mode: SaveMode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 854a662cc4..d160f8ab8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -286,8 +286,8 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. - */ + * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. 
+ */ def assertEmptyMissingInput(query: Dataset[_]): Unit = { assert(query.queryExecution.analyzed.missingInput.isEmpty, s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 55906793c0..289e1b6db9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -32,10 +32,10 @@ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark /** - * Benchmark to measure whole stage codegen performance. - * To run this: - * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" - */ + * Benchmark to measure whole stage codegen performance. + * To run this: + * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" + */ class BenchmarkWholeStageCodegen extends SparkFunSuite { lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") .set("spark.sql.shuffle.partitions", "1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala index dc54883277..aaeecef5f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.SparkFunSuite /** - * test cases for StringIteratorReader - */ + * test cases for StringIteratorReader + */ class CSVParserSuite extends SparkFunSuite { private def readAll(iter: Iterator[String]) = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index c1e151d08b..ac37e8e022 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -497,9 +497,10 @@ class StreamingContext private[streaming] ( new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) } - /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for - * receiving system events related to streaming. - */ + /** + * Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. + */ def addStreamingListener(streamingListener: StreamingListener) { scheduler.listenerBus.addListener(streamingListener) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 05f4da6fac..922e4a5e4d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -517,9 +517,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.remember(duration) } - /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for - * receiving system events related to streaming. 
- */ + /** + * Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. + */ def addStreamingListener(streamingListener: StreamingListener) { ssc.addStreamingListener(streamingListener) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 0a861f22b1..fbac4880bd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -22,17 +22,18 @@ import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging -/** Provides waitToPush() method to limit the rate at which receivers consume data. - * - * waitToPush method will block the thread if too many messages have been pushed too quickly, - * and only return when a new message has been pushed. It assumes that only one message is - * pushed at a time. - * - * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages - * per second that each receiver will accept. - * - * @param conf spark configuration - */ +/** + * Provides waitToPush() method to limit the rate at which receivers consume data. + * + * waitToPush method will block the thread if too many messages have been pushed too quickly, + * and only return when a new message has been pushed. It assumes that only one message is + * pushed at a time. + * + * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages + * per second that each receiver will accept. + * + * @param conf spark configuration + */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { // treated as an upper limit diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 66d5ffb797..0baedaf275 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -21,9 +21,10 @@ import scala.collection.mutable.HashSet import org.apache.spark.streaming.Time -/** Class representing a set of Jobs - * belong to the same batch. - */ +/** + * Class representing a set of Jobs + * belong to the same batch. + */ private[streaming] case class JobSet( time: Time, diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 0df3c501de..c9058ff409 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -91,10 +91,11 @@ object GenerateMIMAIgnore { (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) } - /** Scala reflection does not let us see inner function even if they are upgraded - * to public for some reason. So had to resort to java reflection to get all inner - * functions with $$ in there name. - */ + /** + * Scala reflection does not let us see inner function even if they are upgraded + * to public for some reason. So had to resort to java reflection to get all inner + * functions with $$ in there name. 
+ */ def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = { try { Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 5af2c29808..4b36da309d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -135,8 +135,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } /** - * Obtains token for the Hive metastore and adds them to the credentials. - */ + * Obtains token for the Hive metastore and adds them to the credentials. + */ def obtainTokenForHiveMetastore( sparkConf: SparkConf, conf: Configuration, @@ -149,8 +149,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } /** - * Obtain a security token for HBase. - */ + * Obtain a security token for HBase. + */ def obtainTokenForHBase( sparkConf: SparkConf, conf: Configuration, @@ -164,10 +164,10 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } /** - * Return whether delegation tokens should be retrieved for the given service when security is - * enabled. By default, tokens are retrieved, but that behavior can be changed by setting - * a service-specific configuration. - */ + * Return whether delegation tokens should be retrieved for the given service when security is + * enabled. By default, tokens are retrieved, but that behavior can be changed by setting + * a service-specific configuration. + */ private def shouldGetTokens(conf: SparkConf, service: String): Boolean = { conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true) } -- cgit v1.2.3 From 03d130f9734be66e8aefc4ffaa207ee13e837629 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Sat, 2 Apr 2016 17:55:46 -0700 Subject: [SPARK-14342][CORE][DOCS][TESTS] Remove straggler references to Tachyon ## What changes were proposed in this pull request? Straggler references to Tachyon were removed: - for docs, `tachyon` has been generalized as `off-heap memory`; - for Mesos test suits, the key-value `tachyon:true`/`tachyon:false` has been changed to `os:centos`/`os:ubuntu`, since `os` is an example constrain used by the [Mesos official docs](http://mesos.apache.org/documentation/attributes-resources/). ## How was this patch tested? Existing test suites. Author: Liwei Lin Closes #12129 from lw-lin/tachyon-cleanup. --- .../org/apache/spark/api/java/StorageLevels.java | 4 +-- .../cluster/mesos/MesosSchedulerUtils.scala | 4 +-- .../cluster/mesos/MesosSchedulerUtilsSuite.scala | 32 +++++++++++----------- docs/running-on-mesos.md | 4 +-- docs/streaming-programming-guide.md | 2 +- python/pyspark/storagelevel.py | 2 +- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java index 666c797738..23673d3e3d 100644 --- a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java +++ b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java @@ -39,8 +39,8 @@ public class StorageLevels { /** * Create a new StorageLevel object. 
* @param useDisk saved to disk, if true - * @param useMemory saved to memory, if true - * @param useOffHeap saved to Tachyon, if true + * @param useMemory saved to on-heap memory, if true + * @param useOffHeap saved to off-heap memory, if true * @param deserialized saved as deserialized objects, if true * @param replication replication factor */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 35f914355d..233bdc23e6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -283,11 +283,11 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * are separated by ':'. The ':' implies equality (for singular values) and "is one of" for * multiple values (comma separated). For example: * {{{ - * parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") + * parseConstraintString("os:centos7;zone:us-east-1a,us-east-1b") * // would result in * * Map( - * "tachyon" -> Set("true"), + * "os" -> Set("centos7"), * "zone": -> Set("us-east-1a", "us-east-1b") * ) * }}} diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala index 85437b2f80..ceb3a52983 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtilsSuite.scala @@ -59,10 +59,10 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("parse a non-empty constraint string correctly") { val expectedMap = Map( - "tachyon" -> Set("true"), + "os" -> Set("centos7"), "zone" -> Set("us-east-1a", "us-east-1b") ) - utils.parseConstraintString("tachyon:true;zone:us-east-1a,us-east-1b") should be (expectedMap) + utils.parseConstraintString("os:centos7;zone:us-east-1a,us-east-1b") should be (expectedMap) } test("parse an empty constraint string correctly") { @@ -71,35 +71,35 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS test("throw an exception when the input is malformed") { an[IllegalArgumentException] should be thrownBy - utils.parseConstraintString("tachyon;zone:us-east") + utils.parseConstraintString("os;zone:us-east") } test("empty values for attributes' constraints matches all values") { - val constraintsStr = "tachyon:" + val constraintsStr = "os:" val parsedConstraints = utils.parseConstraintString(constraintsStr) - parsedConstraints shouldBe Map("tachyon" -> Set()) + parsedConstraints shouldBe Map("os" -> Set()) val zoneSet = Value.Set.newBuilder().addItem("us-east-1a").addItem("us-east-1b").build() - val noTachyonOffer = Map("zone" -> zoneSet) - val tachyonTrueOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) - val tachyonFalseOffer = Map("tachyon" -> Value.Text.newBuilder().setValue("false").build()) + val noOsOffer = Map("zone" -> zoneSet) + val centosOffer = Map("os" -> Value.Text.newBuilder().setValue("centos").build()) + val ubuntuOffer = Map("os" -> Value.Text.newBuilder().setValue("ubuntu").build()) - utils.matchesAttributeRequirements(parsedConstraints, noTachyonOffer) shouldBe false - utils.matchesAttributeRequirements(parsedConstraints, tachyonTrueOffer) shouldBe true - 
utils.matchesAttributeRequirements(parsedConstraints, tachyonFalseOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, noOsOffer) shouldBe false + utils.matchesAttributeRequirements(parsedConstraints, centosOffer) shouldBe true + utils.matchesAttributeRequirements(parsedConstraints, ubuntuOffer) shouldBe true } test("subset match is performed for set attributes") { val supersetConstraint = Map( - "tachyon" -> Value.Text.newBuilder().setValue("true").build(), + "os" -> Value.Text.newBuilder().setValue("ubuntu").build(), "zone" -> Value.Set.newBuilder() .addItem("us-east-1a") .addItem("us-east-1b") .addItem("us-east-1c") .build()) - val zoneConstraintStr = "tachyon:;zone:us-east-1a,us-east-1c" + val zoneConstraintStr = "os:;zone:us-east-1a,us-east-1c" val parsedConstraints = utils.parseConstraintString(zoneConstraintStr) utils.matchesAttributeRequirements(parsedConstraints, supersetConstraint) shouldBe true @@ -131,10 +131,10 @@ class MesosSchedulerUtilsSuite extends SparkFunSuite with Matchers with MockitoS } test("equality match is performed for text attributes") { - val offerAttribs = Map("tachyon" -> Value.Text.newBuilder().setValue("true").build()) + val offerAttribs = Map("os" -> Value.Text.newBuilder().setValue("centos7").build()) - val trueConstraint = utils.parseConstraintString("tachyon:true") - val falseConstraint = utils.parseConstraintString("tachyon:false") + val trueConstraint = utils.parseConstraintString("os:centos7") + val falseConstraint = utils.parseConstraintString("os:ubuntu") utils.matchesAttributeRequirements(trueConstraint, offerAttribs) shouldBe true utils.matchesAttributeRequirements(falseConstraint, offerAttribs) shouldBe false diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 293a82882e..8e47301a75 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -215,10 +215,10 @@ conf.set("spark.mesos.coarse", "false") You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted. {% highlight scala %} -conf.set("spark.mesos.constraints", "tachyon:true;us-east-1:false") +conf.set("spark.mesos.constraints", "os:centos7;us-east-1:false") {% endhighlight %} -For example, Let's say `spark.mesos.constraints` is set to `tachyon:true;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. +For example, Let's say `spark.mesos.constraints` is set to `os:centos7;us-east-1:false`, then the resource offers will be checked to see if they meet both these constraints and only then will be accepted to start new executors. # Mesos Docker Support diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 8d21917a7d..7f6c0ed699 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -2178,7 +2178,7 @@ overall processing throughput of the system, its use is still recommended to ach consistent batch processing times. Make sure you set the CMS GC on both the driver (using `--driver-java-options` in `spark-submit`) and the executors (using [Spark configuration](configuration.html#runtime-environment) `spark.executor.extraJavaOptions`). * **Other tips**: To further reduce GC overheads, here are some more tips to try. - - Use Tachyon for off-heap storage of persisted RDDs. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence). 
+ - Persist RDDs using the `OFF_HEAP` storage level. See more detail in the [Spark Programming Guide](programming-guide.html#rdd-persistence). - Use more executors with smaller heap sizes. This will reduce the GC pressure within each JVM heap. diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index d4f184a85d..176e3bb41c 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -44,7 +44,7 @@ class StorageLevel(object): result = "" result += "Disk " if self.useDisk else "" result += "Memory " if self.useMemory else "" - result += "Tachyon " if self.useOffHeap else "" + result += "OffHeap " if self.useOffHeap else "" result += "Deserialized " if self.deserialized else "Serialized " result += "%sx Replicated" % self.replication return result -- cgit v1.2.3 From 1cf70183423b938ec064925b20fd4a5b9e355991 Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Sat, 2 Apr 2016 19:17:25 -0700 Subject: [SPARK-14056] Appends s3 specific configurations and spark.hadoop con… MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Appends s3 specific configurations and spark.hadoop configurations to hive configuration. ## How was this patch tested? Tested by running a job on cluster. …figurations to hive configuration. Author: Sital Kedia Closes #11876 from sitalkedia/hiveConf. --- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 19 +++++++++++++------ .../scala/org/apache/spark/sql/hive/TableReader.scala | 4 ++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 06b7b388ca..4e8e363635 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -74,13 +74,12 @@ class SparkHadoopUtil extends Logging { } } - /** - * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop - * subsystems. - */ - def newConfiguration(conf: SparkConf): Configuration = { - val hadoopConf = new Configuration() + /** + * Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop + * configuration. + */ + def appendS3AndSparkHadoopConfigurations(conf: SparkConf, hadoopConf: Configuration): Unit = { // Note: this null check is around more than just access to the "conf" object to maintain // the behavior of the old implementation of this code, for backwards compatibility. if (conf != null) { @@ -106,7 +105,15 @@ class SparkHadoopUtil extends Logging { val bufferSize = conf.get("spark.buffer.size", "65536") hadoopConf.set("io.file.buffer.size", bufferSize) } + } + /** + * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop + * subsystems. 
+ */ + def newConfiguration(conf: SparkConf): Configuration = { + val hadoopConf = new Configuration() + appendS3AndSparkHadoopConfigurations(conf, hadoopConf) hadoopConf } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 80b24dc989..54afe9c2a3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -34,6 +34,7 @@ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.InternalRow @@ -74,8 +75,7 @@ class HadoopTableReader( math.max(sc.hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinPartitions) } - // TODO: set aws s3 credentials. - + SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveExtraConf) private val _broadcastedHiveConf = sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) -- cgit v1.2.3 From c2f25b1a148eeb1791ea7018b14b3a665c13212a Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 2 Apr 2016 19:34:38 -0700 Subject: [SPARK-13996] [SQL] Add more not null attributes for Filter codegen ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-13996 Filter codegen finds the attributes not null by checking IsNotNull(a) expression with a condition if child.output.contains(a). However, the current approach to checking it is not comprehensive. We can improve it. E.g., for this plan: val rdd = sqlContext.sparkContext.makeRDD(Seq(Row(1, "1"), Row(null, "1"), Row(2, "2"))) val schema = new StructType().add("k", IntegerType).add("v", StringType) val smallDF = sqlContext.createDataFrame(rdd, schema) val df = smallDF.filter("isnotnull(k + 1)") The code snippet generated without this patch: /* 031 */ protected void processNext() throws java.io.IOException { /* 032 */ /*** PRODUCE: Filter isnotnull((k#0 + 1)) */ /* 033 */ /* 034 */ /*** PRODUCE: INPUT */ /* 035 */ /* 036 */ while (!shouldStop() && inputadapter_input.hasNext()) { /* 037 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 038 */ /*** CONSUME: Filter isnotnull((k#0 + 1)) */ /* 039 */ /* input[0, int] */ /* 040 */ boolean filter_isNull = inputadapter_row.isNullAt(0); /* 041 */ int filter_value = filter_isNull ? -1 : (inputadapter_row.getInt(0)); /* 042 */ /* 043 */ /* isnotnull((input[0, int] + 1)) */ /* 044 */ /* (input[0, int] + 1) */ /* 045 */ boolean filter_isNull3 = true; /* 046 */ int filter_value3 = -1; /* 047 */ /* 048 */ if (!filter_isNull) { /* 049 */ filter_isNull3 = false; // resultCode could change nullability. 
/* 050 */ filter_value3 = filter_value + 1; /* 051 */ /* 052 */ } /* 053 */ if (!(!(filter_isNull3))) continue; /* 054 */ /* 055 */ filter_metricValue.add(1); With this patch: /* 031 */ protected void processNext() throws java.io.IOException { /* 032 */ /*** PRODUCE: Filter isnotnull((k#0 + 1)) */ /* 033 */ /* 034 */ /*** PRODUCE: INPUT */ /* 035 */ /* 036 */ while (!shouldStop() && inputadapter_input.hasNext()) { /* 037 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 038 */ /*** CONSUME: Filter isnotnull((k#0 + 1)) */ /* 039 */ /* input[0, int] */ /* 040 */ boolean filter_isNull = inputadapter_row.isNullAt(0); /* 041 */ int filter_value = filter_isNull ? -1 : (inputadapter_row.getInt(0)); /* 042 */ /* 043 */ if (filter_isNull) continue; /* 044 */ /* 045 */ filter_metricValue.add(1); ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #11810 from viirya/add-more-not-null-attrs. --- .../scala/org/apache/spark/sql/execution/basicOperators.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index a6a14df6a3..fb1c6182cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -79,12 +79,12 @@ case class Filter(condition: Expression, child: SparkPlan) // Split out all the IsNotNulls from condition. private val (notNullPreds, otherPreds) = splitConjunctivePredicates(condition).partition { - case IsNotNull(a) if child.output.exists(_.semanticEquals(a)) => true + case IsNotNull(a: NullIntolerant) if a.references.subsetOf(child.outputSet) => true case _ => false } // The columns that will filtered out by `IsNotNull` could be considered as not nullable. - private val notNullAttributes = notNullPreds.flatMap(_.references) + private val notNullAttributes = notNullPreds.flatMap(_.references).distinct.map(_.exprId) // Mark this as empty. We'll evaluate the input during doConsume(). We don't want to evaluate // all the variables at the beginning to take advantage of short circuiting. @@ -92,7 +92,7 @@ case class Filter(condition: Expression, child: SparkPlan) override def output: Seq[Attribute] = { child.output.map { a => - if (a.nullable && notNullAttributes.exists(_.semanticEquals(a))) { + if (a.nullable && notNullAttributes.contains(a.exprId)) { a.withNullability(false) } else { a @@ -179,7 +179,7 @@ case class Filter(condition: Expression, child: SparkPlan) // Reset the isNull to false for the not-null columns, then the followed operators could // generate better code (remove dead branches). 
val resultVars = input.zipWithIndex.map { case (ev, i) => - if (notNullAttributes.exists(_.semanticEquals(child.output(i)))) { + if (notNullAttributes.contains(child.output(i).exprId)) { ev.isNull = "false" } ev -- cgit v1.2.3 From 7be46205083fc688249ee619ac7758904f7aa55d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 2 Apr 2016 23:05:23 -0700 Subject: [HOTFIX] Fix Scala 2.10 compilation --- .../spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 33239c0084..c02fec3085 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -41,7 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val trueBranch = (TrueLiteral, Literal(5)) private val normalBranch = (NonFoldableLiteral(true), Literal(10)) private val unreachableBranch = (FalseLiteral, Literal(20)) - private val nullBranch = (Literal(null, NullType), Literal(30)) + private val nullBranch = (Literal.create(null, NullType), Literal(30)) test("simplify if") { assertEquivalent( @@ -53,7 +53,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Literal(20)) assertEquivalent( - If(Literal(null, NullType), Literal(10), Literal(20)), + If(Literal.create(null, NullType), Literal(10), Literal(20)), Literal(20)) } -- cgit v1.2.3 From 2262a93358c2f6d4cfb73645c4ebc963c5640ec8 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 2 Apr 2016 23:12:04 -0700 Subject: [SPARK-14231] [SQL] JSON data source infers floating-point values as a double when they do not fit in a decimal ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14231 Currently, JSON data source supports to infer `DecimalType` for big numbers and `floatAsBigDecimal` option which reads floating-point values as `DecimalType`. But there are few restrictions in Spark `DecimalType` below: 1. The precision cannot be bigger than 38. 2. scale cannot be bigger than precision. Currently, both restrictions are not being handled. This PR handles the cases by inferring them as `DoubleType`. Also, the option name was changed from `floatAsBigDecimal` to `prefersDecimal` as suggested [here](https://issues.apache.org/jira/browse/SPARK-14231?focusedCommentId=15215579&page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel#comment-15215579). So, the codes below: ```scala def doubleRecords: RDD[String] = sqlContext.sparkContext.parallelize( s"""{"a": 1${"0" * 38}, "b": 0.01}""" :: s"""{"a": 2${"0" * 38}, "b": 0.02}""" :: Nil) val jsonDF = sqlContext.read .option("prefersDecimal", "true") .json(doubleRecords) jsonDF.printSchema() ``` produces below: - **Before** ```scala org.apache.spark.sql.AnalysisException: Decimal scale (2) cannot be greater than precision (1).; at org.apache.spark.sql.types.DecimalType.(DecimalType.scala:44) at org.apache.spark.sql.execution.datasources.json.InferSchema$.org$apache$spark$sql$execution$datasources$json$InferSchema$$inferField(InferSchema.scala:144) at org.apache.spark.sql.execution.datasources.json.InferSchema$.org$apache$spark$sql$execution$datasources$json$InferSchema$$inferField(InferSchema.scala:108) at ... 
``` - **After** ```scala root |-- a: double (nullable = true) |-- b: double (nullable = true) ``` ## How was this patch tested? Unit tests were used and `./dev/run_tests` for coding style tests. Author: hyukjinkwon Closes #12030 from HyukjinKwon/SPARK-14231. --- python/pyspark/sql/readwriter.py | 4 +- .../org/apache/spark/sql/DataFrameReader.scala | 4 +- .../execution/datasources/json/InferSchema.scala | 17 +++++--- .../execution/datasources/json/JSONOptions.scala | 4 +- .../sql/execution/datasources/json/JsonSuite.scala | 48 +++++++++++++++++++++- .../execution/datasources/json/TestJsonData.scala | 8 ++++ 6 files changed, 71 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index cca57a385c..0cef37e57c 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -152,8 +152,8 @@ class DataFrameReader(object): You can set the following JSON-specific options to deal with non-standard JSON files: * ``primitivesAsString`` (default ``false``): infers all primitive values as a string \ type - * `floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal \ - type + * `prefersDecimal` (default `false`): infers all floating-point values as a decimal \ + type. If the values do not fit in decimal, then it infers them as doubles. * ``allowComments`` (default ``false``): ignores Java/C++ style comment in JSON records * ``allowUnquotedFieldNames`` (default ``false``): allows unquoted JSON field names * ``allowSingleQuotes`` (default ``true``): allows single quotes in addition to double \ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 704535adaa..a5a6e01e99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -315,8 +315,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * * You can set the following JSON-specific options to deal with non-standard JSON files: *
 * <li>`primitivesAsString` (default `false`): infers all primitive values as a string type
 * </li>
- * <li>`floatAsBigDecimal` (default `false`): infers all floating-point values as a decimal
- * type</li>
+ * <li>`prefersDecimal` (default `false`): infers all floating-point values as a decimal
+ * type. If the values do not fit in decimal, then it infers them as doubles.</li>
 * <li>`allowComments` (default `false`): ignores Java/C++ style comment in JSON records</li>
 * <li>`allowUnquotedFieldNames` (default `false`): allows unquoted JSON field names</li>
 *
  • `allowSingleQuotes` (default `true`): allows single quotes in addition to double quotes diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 945ed2c211..4a34f365e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -25,7 +25,6 @@ import org.apache.spark.sql.execution.datasources.json.JacksonUtils.nextUntil import org.apache.spark.sql.types._ import org.apache.spark.util.Utils - private[sql] object InferSchema { /** @@ -135,14 +134,20 @@ private[sql] object InferSchema { // when we see a Java BigInteger, we use DecimalType. case BIG_INTEGER | BIG_DECIMAL => val v = parser.getDecimalValue - DecimalType(v.precision(), v.scale()) - case FLOAT | DOUBLE => - if (configOptions.floatAsBigDecimal) { - val v = parser.getDecimalValue - DecimalType(v.precision(), v.scale()) + if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) { + DecimalType(Math.max(v.precision(), v.scale()), v.scale()) + } else { + DoubleType + } + case FLOAT | DOUBLE if configOptions.prefersDecimal => + val v = parser.getDecimalValue + if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) { + DecimalType(Math.max(v.precision(), v.scale()), v.scale()) } else { DoubleType } + case FLOAT | DOUBLE => + DoubleType } case VALUE_TRUE | VALUE_FALSE => BooleanType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index c0ad9efcb7..66f1126fb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -35,8 +35,8 @@ private[sql] class JSONOptions( parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0) val primitivesAsString = parameters.get("primitivesAsString").map(_.toBoolean).getOrElse(false) - val floatAsBigDecimal = - parameters.get("floatAsBigDecimal").map(_.toBoolean).getOrElse(false) + val prefersDecimal = + parameters.get("prefersDecimal").map(_.toBoolean).getOrElse(false) val allowComments = parameters.get("allowComments").map(_.toBoolean).getOrElse(false) val allowUnquotedFieldNames = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index c108d81b18..421862c394 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -745,8 +745,8 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } - test("Loading a JSON dataset floatAsBigDecimal returns schema with float types as BigDecimal") { - val jsonDF = sqlContext.read.option("floatAsBigDecimal", "true").json(primitiveFieldAndType) + test("Loading a JSON dataset prefersDecimal returns schema with float types as BigDecimal") { + val jsonDF = sqlContext.read.option("prefersDecimal", "true").json(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType(20, 0), true) :: @@ -773,6 +773,50 @@ class JsonSuite extends QueryTest with 
SharedSQLContext with TestJsonData { ) } + test("Infer big integers correctly even when it does not fit in decimal") { + val jsonDF = sqlContext.read + .json(bigIntegerRecords) + + // The value in `a` field will be a double as it does not fit in decimal. For `b` field, + // it will be a decimal as `92233720368547758070`. + val expectedSchema = StructType( + StructField("a", DoubleType, true) :: + StructField("b", DecimalType(20, 0), true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + checkAnswer(jsonDF, Row(1.0E38D, BigDecimal("92233720368547758070"))) + } + + test("Infer floating-point values correctly even when it does not fit in decimal") { + val jsonDF = sqlContext.read + .option("prefersDecimal", "true") + .json(floatingValueRecords) + + // The value in `a` field will be a double as it does not fit in decimal. For `b` field, + // it will be a decimal as `0.01` by having a precision equal to the scale. + val expectedSchema = StructType( + StructField("a", DoubleType, true) :: + StructField("b", DecimalType(2, 2), true):: Nil) + + assert(expectedSchema === jsonDF.schema) + checkAnswer(jsonDF, Row(1.0E-39D, BigDecimal(0.01))) + + val mergedJsonDF = sqlContext.read + .option("prefersDecimal", "true") + .json(floatingValueRecords ++ bigIntegerRecords) + + val expectedMergedSchema = StructType( + StructField("a", DoubleType, true) :: + StructField("b", DecimalType(22, 2), true):: Nil) + + assert(expectedMergedSchema === mergedJsonDF.schema) + checkAnswer( + mergedJsonDF, + Row(1.0E-39D, BigDecimal(0.01)) :: + Row(1.0E38D, BigDecimal("92233720368547758070")) :: Nil + ) + } + test("Loading a JSON dataset from a text file with SQL") { val dir = Utils.createTempDir() dir.delete() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala index b2eff816ee..2873c6a881 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala @@ -214,6 +214,14 @@ private[json] trait TestJsonData { """{"a": {"b": 1}}""" :: """{"a": []}""" :: Nil) + def floatingValueRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + s"""{"a": 0.${"0" * 38}1, "b": 0.01}""" :: Nil) + + def bigIntegerRecords: RDD[String] = + sqlContext.sparkContext.parallelize( + s"""{"a": 1${"0" * 38}, "b": 92233720368547758070}""" :: Nil) + lazy val singleRow: RDD[String] = sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil) def empty: RDD[String] = sqlContext.sparkContext.parallelize(Seq[String]()) -- cgit v1.2.3 From 1f0c5dcebba1f9d1149043a496e0175f78252bae Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 3 Apr 2016 15:33:29 +0200 Subject: [SPARK-14350][SQL] EXPLAIN output should be in a single cell ## What changes were proposed in this pull request? EXPLAIN output should be in a single cell. 
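A rough usage sketch (hypothetical, assuming a `sqlContext` is in scope as in spark-shell): with the single-cell format, the whole plan can be read back as one string.

```scala
// Hypothetical illustration: the entire physical plan is now one string in one Row.
val planText = sqlContext.sql("EXPLAIN SELECT 1").head.getString(0)
println(planText)   // prints the multi-line "== Physical Plan ==" block
```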
**Before** ``` scala> sql("explain select 1").collect() res0: Array[org.apache.spark.sql.Row] = Array([== Physical Plan ==], [WholeStageCodegen], [: +- Project [1 AS 1#1]], [: +- INPUT], [+- Scan OneRowRelation[]]) ``` **After** ``` scala> sql("explain select 1").collect() res1: Array[org.apache.spark.sql.Row] = Array([== Physical Plan == WholeStageCodegen : +- Project [1 AS 1#4] : +- INPUT +- Scan OneRowRelation[]]) ``` Or, ``` scala> sql("explain select 1").head res1: org.apache.spark.sql.Row = [== Physical Plan == WholeStageCodegen : +- Project [1 AS 1#5] : +- INPUT +- Scan OneRowRelation[]] ``` Please note that `Spark-shell(Scala-shell)` trims long string output. So, you may need to use `println` to get full strings. ``` scala> println(sql("explain codegen select 'a' as a group by 1").head) [Found 2 WholeStageCodegen subtrees. == Subtree 1 / 2 == WholeStageCodegen ... /* 059 */ } /* 060 */ } ] ``` ## How was this patch tested? Pass the Jenkins tests. (Testcases are updated.) Author: Dongjoon Hyun Closes #12137 from dongjoon-hyun/SPARK-14350. --- .../main/scala/org/apache/spark/sql/execution/command/commands.scala | 2 +- .../test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 4bc62cdc4a..4eb8d7ff0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -253,7 +253,7 @@ case class ExplainCommand( } else { queryExecution.simpleString } - outputString.split("\n").map(Row(_)) + Seq(Row(outputString)) } catch { case cause: TreeNodeException[_] => ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 58259060bf..5450368b88 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -707,7 +707,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { def isExplanation(result: DataFrame): Boolean = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.contains("== Physical Plan ==") + explanation.head.startsWith("== Physical Plan ==") } test("SPARK-1704: Explain commands as a DataFrame") { -- cgit v1.2.3 From c238cd07448f94bbceb661daad90b6a6d597a846 Mon Sep 17 00:00:00 2001 From: bomeng Date: Sun, 3 Apr 2016 17:15:02 +0200 Subject: [SPARK-14341][SQL] Throw exception on unsupported create / drop macro ddl ## What changes were proposed in this pull request? 
We throw an AnalysisException that looks like this: ``` scala> sqlContext.sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))") org.apache.spark.sql.catalyst.parser.ParseException: Unsupported SQL statement == SQL == CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x)) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.nativeCommand(ParseDriver.scala:66) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser$$anonfun$parsePlan$1.apply(ParseDriver.scala:56) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser$$anonfun$parsePlan$1.apply(ParseDriver.scala:53) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse(ParseDriver.scala:86) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parsePlan(ParseDriver.scala:53) at org.apache.spark.sql.SQLContext.parseSql(SQLContext.scala:198) at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:749) ... 48 elided ``` ## How was this patch tested? Add test cases in HiveQuerySuite.scala Author: bomeng Closes #12125 from bomeng/SPARK-14341. --- .../main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 3 +++ .../apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala | 6 ++++-- .../scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 7 +++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index f34bb061e4..6cf47b5c30 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -173,6 +173,8 @@ unsupportedHiveNativeCommands | kw1=LOCK kw2=DATABASE | kw1=UNLOCK kw2=TABLE | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO ; createTableHeader @@ -759,6 +761,7 @@ SNAPSHOT: 'SNAPSHOT'; READ: 'READ'; WRITE: 'WRITE'; ONLY: 'ONLY'; +MACRO: 'MACRO'; IF: 'IF'; diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index d8695bc5db..4b4f88ece0 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -363,7 +363,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", - "alter_index" + "alter_index", + + // Macro commands are not supported + "macro" ) /** @@ -733,7 +736,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_file_with_space_in_the_name", "loadpart1", "louter_join_ppr", - "macro", "mapjoin_distinct", "mapjoin_filter_on_outerjoin", "mapjoin_mapjoin", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 5450368b88..b951948fda 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1295,6 +1295,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertUnsupportedFeature { sql("ALTER INDEX 
my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")")} } + + test("create/drop macro commands are not supported") { + assertUnsupportedFeature { + sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))") + } + assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") } + } } // for SPARK-2180 test -- cgit v1.2.3 From 9023015f059327b3ce4a7eaf71e57ac77b84ad7b Mon Sep 17 00:00:00 2001 From: Marcin Tustin Date: Sun, 3 Apr 2016 17:42:33 -0700 Subject: [SPARK-14163][CORE] SumEvaluator and countApprox cannot reliably handle RDDs of size 1 ## What changes were proposed in this pull request? This special cases 0 and 1 counts to avoid passing 0 degrees of freedom. ## How was this patch tested? Tests run successfully. New test added. ## Note: This recreates #11982 which was closed to due to non-updated diff. rxin srowen Commented there. This also adds tests, reworks the code to perform the special casing (based on srowen's comments), and adds equality machinery for BoundedDouble, as well as changing how it is transformed to string. Author: Marcin Tustin Author: Marcin Tustin Closes #12016 from mtustin-handy/SPARK-14163. --- .../org/apache/spark/partial/BoundedDouble.scala | 18 ++++ .../org/apache/spark/partial/SumEvaluator.scala | 36 ++++--- .../apache/spark/partial/SumEvaluatorSuite.scala | 107 +++++++++++++++++++++ 3 files changed, 148 insertions(+), 13 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index 48b9434153..d06b2c67d2 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -21,5 +21,23 @@ package org.apache.spark.partial * A Double value with error bars and associated confidence. 
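 * (Illustrative sketch only, not from the original patch: a `BoundedDouble` is what
 * approximate actions such as `RDD.countApprox` hand back through a `PartialResult`;
 * `sc` below is an assumed active SparkContext.)
 * {{{
 *   val partial = sc.parallelize(1 to 1000).countApprox(timeout = 1000L)
 *   val estimate: BoundedDouble = partial.initialValue
 *   println(s"mean=${estimate.mean}, interval=[${estimate.low}, ${estimate.high}]")
 * }}}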
*/ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { + override def toString(): String = "[%.3f, %.3f]".format(low, high) + + override def hashCode: Int = + this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode + + /** + * Note that consistent with Double, any NaN value will make equality false + */ + override def equals(that: Any): Boolean = + that match { + case that: BoundedDouble => { + this.mean == that.mean && + this.confidence == that.confidence && + this.low == that.low && + this.high == that.high + } + case _ => false + } } diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index 44295e5a1a..5fe3358316 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -29,8 +29,9 @@ import org.apache.spark.util.StatCounter private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) extends ApproximateEvaluator[StatCounter, BoundedDouble] { + // modified in merge var outputsMerged = 0 - var counter = new StatCounter + val counter = new StatCounter override def merge(outputId: Int, taskResult: StatCounter) { outputsMerged += 1 @@ -40,30 +41,39 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) override def currentResult(): BoundedDouble = { if (outputsMerged == totalOutputs) { new BoundedDouble(counter.sum, 1.0, counter.sum, counter.sum) - } else if (outputsMerged == 0) { + } else if (outputsMerged == 0 || counter.count == 0) { new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity) } else { val p = outputsMerged.toDouble / totalOutputs val meanEstimate = counter.mean - val meanVar = counter.sampleVariance / counter.count val countEstimate = (counter.count + 1 - p) / p - val countVar = (counter.count + 1) * (1 - p) / (p * p) val sumEstimate = meanEstimate * countEstimate - val sumVar = (meanEstimate * meanEstimate * countVar) + - (countEstimate * countEstimate * meanVar) + - (meanVar * countVar) - val sumStdev = math.sqrt(sumVar) - val confFactor = { - if (counter.count > 100) { + + val meanVar = counter.sampleVariance / counter.count + + // branch at this point because counter.count == 1 implies counter.sampleVariance == Nan + // and we don't want to ever return a bound of NaN + if (meanVar.isNaN || counter.count == 1) { + new BoundedDouble(sumEstimate, confidence, Double.NegativeInfinity, Double.PositiveInfinity) + } else { + val countVar = (counter.count + 1) * (1 - p) / (p * p) + val sumVar = (meanEstimate * meanEstimate * countVar) + + (countEstimate * countEstimate * meanVar) + + (meanVar * countVar) + val sumStdev = math.sqrt(sumVar) + val confFactor = if (counter.count > 100) { new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) } else { + // note that if this goes to 0, TDistribution will throw an exception. + // Hence special casing 1 above. 
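+          // (TDistribution requires strictly positive degrees of freedom, and
+          //  counter.count == 1 would make it zero, so that case takes the unbounded branch above.)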
val degreesOfFreedom = (counter.count - 1).toInt new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) } + + val low = sumEstimate - confFactor * sumStdev + val high = sumEstimate + confFactor * sumStdev + new BoundedDouble(sumEstimate, confidence, low, high) } - val low = sumEstimate - confFactor * sumStdev - val high = sumEstimate + confFactor * sumStdev - new BoundedDouble(sumEstimate, confidence, low, high) } } } diff --git a/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala new file mode 100644 index 0000000000..a79f5b4d74 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/partial/SumEvaluatorSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.partial + +import org.apache.spark._ +import org.apache.spark.util.StatCounter + +class SumEvaluatorSuite extends SparkFunSuite with SharedSparkContext { + + test("correct handling of count 1") { + + // setup + val counter = new StatCounter(List(2.0)) + // count of 10 because it's larger than 1, + // and 0.95 because that's the default + val evaluator = new SumEvaluator(10, 0.95) + // arbitrarily assign id 1 + evaluator.merge(1, counter) + + // execute + val res = evaluator.currentResult() + // 38.0 - 7.1E-15 because that's how the maths shakes out + val targetMean = 38.0 - 7.1E-15 + + // Sanity check that equality works on BoundedDouble + assert(new BoundedDouble(2.0, 0.95, 1.1, 1.2) == new BoundedDouble(2.0, 0.95, 1.1, 1.2)) + + // actual test + assert(res == + new BoundedDouble(targetMean, 0.950, Double.NegativeInfinity, Double.PositiveInfinity)) + } + + test("correct handling of count 0") { + + // setup + val counter = new StatCounter(List()) + // count of 10 because it's larger than 0, + // and 0.95 because that's the default + val evaluator = new SumEvaluator(10, 0.95) + // arbitrarily assign id 1 + evaluator.merge(1, counter) + + // execute + val res = evaluator.currentResult() + // assert + assert(res == new BoundedDouble(0, 0.0, Double.NegativeInfinity, Double.PositiveInfinity)) + } + + test("correct handling of NaN") { + + // setup + val counter = new StatCounter(List(1, Double.NaN, 2)) + // count of 10 because it's larger than 0, + // and 0.95 because that's the default + val evaluator = new SumEvaluator(10, 0.95) + // arbitrarily assign id 1 + evaluator.merge(1, counter) + + // execute + val res = evaluator.currentResult() + // assert - note semantics of == in face of NaN + assert(res.mean.isNaN) + assert(res.confidence == 0.95) + assert(res.low == Double.NegativeInfinity) + assert(res.high == Double.PositiveInfinity) + } + + test("correct handling of > 1 values") { + + // setup + val counter = new 
StatCounter(List(1, 3, 2)) + // count of 10 because it's larger than 0, + // and 0.95 because that's the default + val evaluator = new SumEvaluator(10, 0.95) + // arbitrarily assign id 1 + evaluator.merge(1, counter) + + // execute + val res = evaluator.currentResult() + + // These vals because that's how the maths shakes out + val targetMean = 78.0 + val targetLow = -117.617 + 2.732357258139473E-5 + val targetHigh = 273.617 - 2.7323572624027292E-5 + val target = new BoundedDouble(targetMean, 0.95, targetLow, targetHigh) + + + // check that values are within expected tolerance of expectation + assert(res == target) + } + +} -- cgit v1.2.3 From 3f749f7ed443899d667c9e2b2a11bc595d6fc7f6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 3 Apr 2016 18:14:16 -0700 Subject: [SPARK-14355][BUILD] Fix typos in Exception/Testcase/Comments and static analysis results ## What changes were proposed in this pull request? This PR contains the following 5 types of maintenance fix over 59 files (+94 lines, -93 lines). - Fix typos(exception/log strings, testcase name, comments) in 44 lines. - Fix lint-java errors (MaxLineLength) in 6 lines. (New codes after SPARK-14011) - Use diamond operators in 40 lines. (New codes after SPARK-13702) - Fix redundant semicolon in 5 lines. - Rename class `InferSchemaSuite` to `CSVInferSchemaSuite` in CSVInferSchemaSuite.scala. ## How was this patch tested? Manual and pass the Jenkins tests. Author: Dongjoon Hyun Closes #12139 from dongjoon-hyun/SPARK-14355. --- .../spark/network/client/TransportClientFactory.java | 2 +- .../spark/network/client/TransportResponseHandler.java | 6 +++--- .../spark/network/server/OneForOneStreamManager.java | 2 +- .../apache/spark/network/sasl/ShuffleSecretManager.java | 2 +- .../collection/unsafe/sort/UnsafeSorterSpillMerger.java | 2 +- core/src/main/scala/org/apache/spark/api/r/RRunner.scala | 2 +- .../scala/org/apache/spark/util/random/RandomSampler.scala | 2 +- .../spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 4 ++-- .../main/java/org/apache/spark/examples/JavaLogQuery.java | 4 ++-- .../mllib/JavaMultiLabelClassificationMetricsExample.java | 14 +++++++------- .../mllib/JavaPowerIterationClusteringExample.java | 10 +++++----- .../examples/mllib/JavaStratifiedSamplingExample.java | 2 +- .../spark/examples/streaming/JavaFlumeEventCount.java | 4 ++-- .../apache/spark/streaming/flume/JavaFlumeStreamSuite.java | 11 ++++++----- .../org/apache/spark/launcher/CommandBuilderUtils.java | 2 +- .../main/java/org/apache/spark/launcher/SparkLauncher.java | 2 +- .../org/apache/spark/launcher/LauncherServerSuite.java | 2 +- .../spark/launcher/SparkSubmitCommandBuilderSuite.java | 2 +- .../spark/ml/classification/DecisionTreeClassifier.scala | 2 +- .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- .../apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../main/scala/org/apache/spark/ml/tree/treeModels.scala | 2 +- .../spark/mllib/classification/LogisticRegression.scala | 2 +- .../java/org/apache/spark/ml/param/JavaTestParams.java | 2 +- .../JavaStreamingLogisticRegressionSuite.java | 4 ++-- .../spark/mllib/clustering/JavaStreamingKMeansSuite.java | 4 ++-- .../org/apache/spark/mllib/linalg/JavaVectorsSuite.java | 4 ++-- .../regression/JavaStreamingLinearRegressionSuite.java | 4 ++-- .../apache/spark/ml/regression/LinearRegressionSuite.scala | 2 +- .../spark/sql/execution/UnsafeExternalRowSorter.java | 2 +- 
.../scala/org/apache/spark/sql/catalyst/CatalystConf.scala | 2 +- .../apache/spark/sql/catalyst/expressions/Expression.scala | 4 ++-- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- .../sql/catalyst/expressions/conditionalExpressions.scala | 2 +- .../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/RowTest.scala | 2 +- .../apache/spark/sql/execution/UnsafeKVExternalSorter.java | 2 +- .../spark/sql/execution/vectorized/ColumnarBatch.java | 4 ++-- .../spark/sql/execution/vectorized/OnHeapColumnVector.java | 2 +- .../main/scala/org/apache/spark/sql/ContinuousQuery.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../main/scala/org/apache/spark/sql/execution/Window.scala | 2 +- .../org/apache/spark/sql/execution/basicOperators.scala | 2 +- .../spark/sql/execution/columnar/ColumnBuilder.scala | 2 +- .../apache/spark/sql/execution/stat/StatFunctions.scala | 2 +- .../spark/sql/execution/streaming/FileStreamSink.scala | 2 +- .../streaming/state/HDFSBackedStateStoreProvider.scala | 2 +- .../spark/sql/execution/streaming/state/StateStore.scala | 2 +- .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- .../java/test/org/apache/spark/sql/JavaDatasetSuite.java | 10 +++++----- .../src/test/scala/org/apache/spark/sql/QueryTest.scala | 4 ++-- .../execution/datasources/csv/CSVInferSchemaSuite.scala | 2 +- .../sql/execution/datasources/parquet/ParquetIOSuite.scala | 2 +- .../org/apache/spark/sql/streaming/FileStressSuite.scala | 2 +- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 4 ++-- .../spark/sql/hive/execution/HiveComparisonTest.scala | 2 +- .../scala/org/apache/spark/sql/hive/parquetSuites.scala | 4 ++-- 59 files changed, 94 insertions(+), 93 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 5a36e18b09..b5a9d6671f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -94,7 +94,7 @@ public class TransportClientFactory implements Closeable { this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); - this.connectionPool = new ConcurrentHashMap(); + this.connectionPool = new ConcurrentHashMap<>(); this.numConnectionsPerPeer = conf.numConnectionsPerPeer(); this.rand = new Random(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index f0e2004d2d..8a69223c88 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -64,9 +64,9 @@ public class TransportResponseHandler extends MessageHandler { public TransportResponseHandler(Channel channel) { this.channel = channel; - this.outstandingFetches = new ConcurrentHashMap(); - this.outstandingRpcs = new ConcurrentHashMap(); - this.streamCallbacks = new ConcurrentLinkedQueue(); + this.outstandingFetches = new 
ConcurrentHashMap<>(); + this.outstandingRpcs = new ConcurrentHashMap<>(); + this.streamCallbacks = new ConcurrentLinkedQueue<>(); this.timeOfLastRequestNs = new AtomicLong(0); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index e2222ae085..ae7e520b2f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -63,7 +63,7 @@ public class OneForOneStreamManager extends StreamManager { // For debugging purposes, start with a random stream id to help identifying different streams. // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); - streams = new ConcurrentHashMap(); + streams = new ConcurrentHashMap<>(); } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index 268cb40121..56a025c4d9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -37,7 +37,7 @@ public class ShuffleSecretManager implements SecretKeyHolder { private static final String SPARK_SASL_USER = "sparkSaslUser"; public ShuffleSecretManager() { - shuffleSecretMap = new ConcurrentHashMap(); + shuffleSecretMap = new ConcurrentHashMap<>(); } /** diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 2b1c860e55..01aed95878 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -45,7 +45,7 @@ final class UnsafeSorterSpillMerger { } } }; - priorityQueue = new PriorityQueue(numSpills, comparator); + priorityQueue = new PriorityQueue<>(numSpills, comparator); } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index ff279ec270..07d1fa2c4a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -182,7 +182,7 @@ private[spark] class RRunner[U]( } stream.flush() } catch { - // TODO: We should propogate this error to the task thread + // TODO: We should propagate this error to the task thread case e: Exception => logError("R Writer thread got an exception", e) } finally { diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index d397cca4b4..8c67364ef1 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -326,7 +326,7 @@ class GapSamplingReplacement( /** * Skip elements with replication factor zero (i.e. elements that won't be sampled). 
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is - * q is the probabililty of Poisson(0; f) + * q is the probability of Poisson(0; f) */ private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 44733dcdaf..30750b1bf1 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -170,11 +170,11 @@ public class UnsafeShuffleWriterSuite { private UnsafeShuffleWriter createWriter( boolean transferToEnabled) throws IOException { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); - return new UnsafeShuffleWriter( + return new UnsafeShuffleWriter<>( blockManager, shuffleBlockResolver, taskMemoryManager, - new SerializedShuffleHandle(0, 1, shuffleDep), + new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 8abc03e73d..ebb0687b14 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -82,10 +82,10 @@ public final class JavaLogQuery { String user = m.group(3); String query = m.group(5); if (!user.equalsIgnoreCase("-")) { - return new Tuple3(ip, user, query); + return new Tuple3<>(ip, user, query); } } - return new Tuple3(null, null, null); + return new Tuple3<>(null, null, null); } public static Stats extractStats(String line) { diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java index 5904260e2d..bc99dc023f 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -34,13 +34,13 @@ public class JavaMultiLabelClassificationMetricsExample { JavaSparkContext sc = new JavaSparkContext(conf); // $example on$ List> data = Arrays.asList( - new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), - new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{}, new double[]{0.0}), - new Tuple2(new double[]{2.0}, new double[]{2.0}), - new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), - new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + new Tuple2<>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2<>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2<>(new double[]{}, new double[]{0.0}), + new Tuple2<>(new double[]{2.0}, new double[]{2.0}), + new Tuple2<>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2<>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2<>(new double[]{1.0}, new double[]{1.0, 2.0}) ); JavaRDD> scoreAndLabels = sc.parallelize(data); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java index 
b62fa90c34..91c3bd72da 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -40,11 +40,11 @@ public class JavaPowerIterationClusteringExample { @SuppressWarnings("unchecked") // $example on$ JavaRDD> similarities = sc.parallelize(Lists.newArrayList( - new Tuple3(0L, 1L, 0.9), - new Tuple3(1L, 2L, 0.9), - new Tuple3(2L, 3L, 0.9), - new Tuple3(3L, 4L, 0.1), - new Tuple3(4L, 5L, 0.9))); + new Tuple3<>(0L, 1L, 0.9), + new Tuple3<>(1L, 2L, 0.9), + new Tuple3<>(2L, 3L, 0.9), + new Tuple3<>(3L, 4L, 0.1), + new Tuple3<>(4L, 5L, 0.9))); PowerIterationClustering pic = new PowerIterationClustering() .setK(2) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java index c27fba2783..86c389e11c 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java @@ -36,7 +36,7 @@ public class JavaStratifiedSamplingExample { JavaSparkContext jsc = new JavaSparkContext(conf); // $example on$ - List> list = new ArrayList>( + List> list = new ArrayList<>( Arrays.>asList( new Tuple2(1, 'a'), new Tuple2(1, 'b'), diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java index da56637fe8..bae4b78ac2 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java @@ -19,7 +19,6 @@ package org.apache.spark.examples.streaming; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.Function; -import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.*; import org.apache.spark.streaming.api.java.*; import org.apache.spark.streaming.flume.FlumeUtils; @@ -58,7 +57,8 @@ public final class JavaFlumeEventCount { Duration batchInterval = new Duration(2000); SparkConf sparkConf = new SparkConf().setAppName("JavaFlumeEventCount"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, batchInterval); - JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(ssc, host, port); + JavaReceiverInputDStream flumeStream = + FlumeUtils.createStream(ssc, host, port); flumeStream.count(); diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java index 3b5e0c7746..ada05f203b 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java @@ -27,10 +27,11 @@ public class JavaFlumeStreamSuite extends LocalJavaStreamingContext { @Test public void testFlumeStream() { // tests the API, does not actually test data receiving - JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", 12345); - JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", 12345, - StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", 12345, - 
StorageLevel.MEMORY_AND_DISK_SER_2(), false); + JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", + 12345); + JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", + 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", + 12345, StorageLevel.MEMORY_AND_DISK_SER_2(), false); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 1e55aad5c9..a08c8dcba4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -34,7 +34,7 @@ class CommandBuilderUtils { /** The set of known JVM vendors. */ enum JavaVendor { Oracle, IBM, OpenJDK, Unknown - }; + } /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index a542159901..a083f05a2a 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -477,6 +477,6 @@ public class SparkLauncher { // No op. } - }; + } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 5bf2babdd1..a9039b3ec9 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -175,7 +175,7 @@ public class LauncherServerSuite extends BaseSuite { TestClient(Socket s) throws IOException { super(s); - this.inbound = new LinkedBlockingQueue(); + this.inbound = new LinkedBlockingQueue<>(); this.clientThread = new Thread(this); clientThread.setName("TestClient"); clientThread.setDaemon(true); diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index b7f4f2efc5..29cbbe825b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -160,7 +160,7 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { "SparkPi", "42"); - Map env = new HashMap(); + Map env = new HashMap<>(); List cmd = buildCommand(sparkSubmitArgs, env); assertEquals("foo", findArgValue(cmd, parser.MASTER)); assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE)); diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 23c4af17f9..4525bf71f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -205,7 +205,7 @@ final class DecisionTreeClassificationModel private[ml] ( @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) - /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */ + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ override 
private[spark] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 3ce129b12c..1d03a5b4f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -62,7 +62,7 @@ private[shared] object SharedParamsCodeGen { "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + - "will filter out rows with bad values), or error (which will throw an errror). More " + + "will filter out rows with bad values), or error (which will throw an error). More " + "options may be added later", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 96263c5baf..64d6af2766 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -270,10 +270,10 @@ private[ml] trait HasFitIntercept extends Params { private[ml] trait HasHandleInvalid extends Params { /** - * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later. * @group param */ - final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later", ParamValidators.inArray(Array("skip", "error"))) + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). 
More options may be added later", ParamValidators.inArray(Array("skip", "error"))) /** @group getParam */ final def getHandleInvalid: String = $(handleInvalid) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 0a3d00e470..1289a317ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -205,7 +205,7 @@ final class DecisionTreeRegressionModel private[ml] ( @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) - /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */ + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ override private[spark] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 1fad9d6d8c..8ea767b2b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -71,7 +71,7 @@ private[spark] trait DecisionTreeModel { */ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() - /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */ + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ private[spark] def toOld: OldDecisionTreeModel } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index c0404be019..f10570e662 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -418,7 +418,7 @@ class LogisticRegressionWithLBFGS private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean): LogisticRegressionModel = { - // ml's Logisitic regression only supports binary classifcation currently. + // ml's Logistic regression only supports binary classification currently. 
if (numOfLinearPredictor == 1) { def runWithMlLogisitcRegression(elasticNetParam: Double) = { // Prepare the ml LogisticRegression based on our settings diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 65841182df..06f7fbb86e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -89,7 +89,7 @@ public class JavaTestParams extends JavaParams { myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); List validStrings = Arrays.asList("a", "b"); - myStringParam_ = new Param(this, "myStringParam", "this is a string param", + myStringParam_ = new Param<>(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); myDoubleArrayParam_ = new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index c9e5ee22f3..62c6d9b7e3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -66,8 +66,8 @@ public class JavaStreamingLogisticRegressionSuite implements Serializable { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index d644766d1e..62edbd3a29 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -66,8 +66,8 @@ public class JavaStreamingKMeansSuite implements Serializable { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingKMeans skmeans = new StreamingKMeans() diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 77c8c6274f..4ba8e543a9 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -37,8 +37,8 @@ public class JavaVectorsSuite implements Serializable { public void sparseArrayConstruction() { @SuppressWarnings("unchecked") Vector v = Vectors.sparse(3, Arrays.asList( - new 
Tuple2(0, 2.0), - new Tuple2(2, 3.0))); + new Tuple2<>(0, 2.0), + new Tuple2<>(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index dbf6488d41..ea0ccd7448 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -65,8 +65,8 @@ public class JavaStreamingLinearRegressionSuite implements Serializable { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index cccb7f8d1b..eb19d13093 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -759,7 +759,7 @@ class LinearRegressionSuite .sliding(2) .forall(x => x(0) >= x(1))) } else { - // To clalify that the normal solver is used here. + // To clarify that the normal solver is used here. assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) val devianceResidualsR = Array(-0.47082, 0.34635) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index aa7fc2121e..7784345a7a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -151,7 +151,7 @@ public final class UnsafeExternalRowSorter { Platform.throwException(e); } throw new RuntimeException("Exception should have been re-thrown in next()"); - }; + } }; } catch (IOException e) { cleanupResources(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index d5ac01500b..2b98aacdd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -26,7 +26,7 @@ private[spark] trait CatalystConf { def groupByOrdinal: Boolean /** - * Returns the [[Resolver]] for the current configuration, which can be used to determin if two + * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
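(A minimal illustrative sketch, not part of the diff above: in Catalyst a Resolver is simply a (String, String) => Boolean used to compare identifiers, so the two stock resolutions amount to the following; the value names mirror the analysis package but are shown only as an example.)

  // Hedged sketch: the equality checks the analyzer can plug in, chosen by caseSensitiveAnalysis.
  val caseSensitiveResolution: (String, String) => Boolean = (a, b) => a == b
  val caseInsensitiveResolution: (String, String) => Boolean = (a, b) => a.equalsIgnoreCase(b)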
*/ def resolver: Resolver = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 5f8899d599..a24a5db8d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -153,8 +153,8 @@ abstract class Expression extends TreeNode[Expression] { * evaluate to the same result. */ lazy val canonicalized: Expression = { - val canonicalizedChildred = children.map(_.canonicalized) - Canonicalize.execute(withNewChildren(canonicalizedChildred)) + val canonicalizedChildren = children.map(_.canonicalized) + Canonicalize.execute(withNewChildren(canonicalizedChildren)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b64d3eea49..1bebd4e904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -509,7 +509,7 @@ class CodegenContext { /** * Checks and sets up the state and codegen for subexpression elimination. This finds the - * common subexpresses, generates the functions that evaluate those expressions and populates + * common subexpressions, generates the functions that evaluate those expressions and populates * the mapping of common subexpressions to the generated functions. */ private def subexpressionElimination(expressions: Seq[Expression]) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 103ab365e3..35a7b46020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -222,7 +222,7 @@ object CaseWhen { } /** - * A factory method to faciliate the creation of this expression when used in parsers. + * A factory method to facilitate the creation of this expression when used in parsers. * @param branches Expressions at even position are the branch conditions, and expressions at odd * position are branch values. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8541b1f7c6..61ea3e4010 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -965,7 +965,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { /** * Create a binary arithmetic expression. The following arithmetic operators are supported: - * - Mulitplication: '*' + * - Multiplication: '*' * - Division: '/' * - Hive Long Division: 'DIV' * - Modulo: '%' @@ -1270,7 +1270,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a double literal for a number denoted in scientifc notation. + * Create a double literal for a number denoted in scientific notation. 
*/ override def visitScientificDecimalLiteral( ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index d9577dea1b..c9c9599e7f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -121,7 +121,7 @@ class RowTest extends FunSpec with Matchers { externalRow should be theSameInstanceAs externalRow.copy() } - it("copy should return same ref for interal rows") { + it("copy should return same ref for internal rows") { internalRow should be theSameInstanceAs internalRow.copy() } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index d3bfb00b3f..8132bba04c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -272,5 +272,5 @@ public final class UnsafeKVExternalSorter { public void close() { cleanupResources(); } - }; + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 792e17911f..d1cc4e6d03 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -79,7 +79,7 @@ public final class ColumnarBatch { /** * Called to close all the columns in this batch. It is not valid to access the data after - * calling this. This must be called at the end to clean up memory allcoations. + * calling this. This must be called at the end to clean up memory allocations. */ public void close() { for (ColumnVector c: columns) { @@ -315,7 +315,7 @@ public final class ColumnarBatch { public int numRows() { return numRows; } /** - * Returns the number of valid rowss. + * Returns the number of valid rows. */ public int numValidRows() { assert(numRowsFiltered <= numRows); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index b1429fe7cb..708a00953a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -212,7 +212,7 @@ public final class OnHeapColumnVector extends ColumnVector { public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; for (int i = 0; i < count; ++i) { - intData[i + rowId] = Platform.getInt(src, srcOffset);; + intData[i + rowId] = Platform.getInt(src, srcOffset); srcIndex += 4; srcOffset += 4; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala index 1dc9a6893e..d9973b092d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -94,7 +94,7 @@ trait ContinuousQuery { /** * Blocks until all available data in the source has been processed an committed to the sink. 
* This method is intended for testing. Note that in the case of continually arriving data, this - * method may block forever. Additionally, this method is only guranteed to block until data that + * method may block forever. Additionally, this method is only guaranteed to block until data that * has been synchronously appended data to a [[org.apache.spark.sql.execution.streaming.Source]] * prior to invocation. (i.e. `getOffset` must immediately reflect the addition). */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 41cb799b97..a39a2113e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2077,7 +2077,7 @@ class Dataset[T] private[sql]( /** * Returns a new [[Dataset]] partitioned by the given partitioning expressions into - * `numPartitions`. The resulting Datasetis hash partitioned. + * `numPartitions`. The resulting Dataset is hash partitioned. * * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5bcc172ca7..e1fabf519a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -108,7 +108,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { /** * Matches a plan whose single partition should be small enough to build a hash table. * - * Note: this assume that the number of partition is fixed, requires addtional work if it's + * Note: this assume that the number of partition is fixed, requires additional work if it's * dynamic. */ def canBuildHashMap(plan: LogicalPlan): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 806089196c..8e9214fa25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -811,7 +811,7 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( * * This is a very expensive operator to use, O(n * (n - 1) /2), because we need to maintain a * buffer and must do full recalculation after each row. Reverse iteration would be possible, if - * the communitativity of the used window functions can be guaranteed. + * the commutativity of the used window functions can be guaranteed. * * @param target to write results to. * @param processor to calculate the row values with. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fb1c6182cf..aba500ad8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -146,7 +146,7 @@ case class Filter(condition: Expression, child: SparkPlan) // This has the property of not doing redundant IsNotNull checks and taking better advantage of // short-circuiting, not loading attributes until they are needed. // This is very perf sensitive. - // TODO: revisit this. We can consider reodering predicates as well. + // TODO: revisit this. 
We can consider reordering predicates as well. val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val generated = otherPreds.map { c => val nullChecks = c.references.map { r => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 7e26f19bb7..9a173367f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -185,7 +185,7 @@ private[columnar] object ColumnBuilder { case udt: UserDefinedType[_] => return apply(udt.sqlType, initialSize, columnName, useCompression) case other => - throw new Exception(s"not suppported type: $other") + throw new Exception(s"not supported type: $other") } builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index e0b6709c51..d603f63a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -296,7 +296,7 @@ private[sql] object StatFunctions extends Logging { val defaultRelativeError: Double = 0.01 /** - * Statisttics from the Greenwald-Khanna paper. + * Statistics from the Greenwald-Khanna paper. * @param value the sampled value * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index e819e95d61..6921ae584d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -32,7 +32,7 @@ object FileStreamSink { /** * A sink that writes out results to parquet files. Each batch is written out to a unique - * directory. After all of the files in a batch have been succesfully written, the list of + * directory. After all of the files in a batch have been successfully written, the list of * file paths is appended to the log atomically. In the case of partial failures, some duplicate * data may be present in the target directory, but only one copy of each file will be present * in the log. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 8ece3c971a..1e0a4a5d4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -178,7 +178,7 @@ private[state] class HDFSBackedStateStoreProvider( * This can be called only after committing all the updates made in the current thread. 
*/ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - verify(state == COMMITTED, "Cannot get iterator of store data before comitting") + verify(state == COMMITTED, "Cannot get iterator of store data before committing") HDFSBackedStateStoreProvider.this.iterator(newVersion) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d60e6185ac..07f63f928b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -220,7 +220,7 @@ private[state] object StateStore extends Logging { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId val verified = coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) - logDebug(s"Verifyied whether the loaded instance $storeId is active: $verified" ) + logDebug(s"Verified whether the loaded instance $storeId is active: $verified" ) verified } catch { case NonFatal(e) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index ca2d909e2c..cfe4911cb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -126,7 +126,7 @@ object JdbcDialects { /** * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. - * Readding an existing dialect will cause a move-to-front. + * Reading an existing dialect will cause a move-to-front. * * @param dialect The new dialect. 
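A minimal usage sketch of the dialect registration described above (the dialect object and the "jdbc:mydb" URL prefix are made-up placeholders; only registerDialect and canHandle come from the API shown here):

  import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

  // Hypothetical dialect: claims any connection URL starting with "jdbc:mydb".
  object MyDialect extends JdbcDialect {
    override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")
  }

  // Registering (or re-registering an existing dialect) moves it to the front of the lookup list.
  JdbcDialects.registerDialect(MyDialect)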
*/ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a5ab446e08..873f681bdf 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -318,14 +318,14 @@ public class JavaDatasetSuite implements Serializable { Encoder> encoder3 = Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); List> data3 = - Arrays.asList(new Tuple3(1, 2L, "a")); + Arrays.asList(new Tuple3<>(1, 2L, "a")); Dataset> ds3 = context.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); Encoder> encoder4 = Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); List> data4 = - Arrays.asList(new Tuple4(1, "b", 2L, "a")); + Arrays.asList(new Tuple4<>(1, "b", 2L, "a")); Dataset> ds4 = context.createDataset(data4, encoder4); Assert.assertEquals(data4, ds4.collectAsList()); @@ -333,7 +333,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(), Encoders.BOOLEAN()); List> data5 = - Arrays.asList(new Tuple5(1, "b", 2L, "a", true)); + Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true)); Dataset> ds5 = context.createDataset(data5, encoder5); Assert.assertEquals(data5, ds5.collectAsList()); @@ -354,7 +354,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.INT(), Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG())); List>> data2 = - Arrays.asList(tuple2(1, new Tuple3("a", "b", 3L))); + Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L))); Dataset>> ds2 = context.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); @@ -376,7 +376,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(), Encoders.FLOAT()); List> data = - Arrays.asList(new Tuple5( + Arrays.asList(new Tuple5<>( 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); Dataset> ds = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d160f8ab8c..f7f3bd78e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -105,10 +105,10 @@ abstract class QueryTest extends PlanTest { val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted - val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") + val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") fail( s"""Decoded objects do not match expected objects: - |$comparision + |$comparison |${ds.resolvedTEncoder.deserializer.treeString} """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 3a7cb25b4f..23d422635b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class InferSchemaSuite extends SparkFunSuite { +class CSVInferSchemaSuite extends SparkFunSuite { test("String fields types are inferred correctly from null types") { assert(CSVInferSchema.inferField(NullType, "") == NullType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 9746187d22..a3017258d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -469,7 +469,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - testQuietly("SPARK-9849 DirectParquetOutputCommitter qualified name backwards compatiblity") { + testQuietly("SPARK-9849 DirectParquetOutputCommitter qualified name backwards compatibility") { val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 3916430cdf..5b49a0a86a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils /** - * A stress test for streamign queries that read and write files. This test constists of + * A stress test for streaming queries that read and write files. This test consists of * two threads: * - one that writes out `numRecords` distinct integers to files of random sizes (the total * number of records is fixed but each files size / creation time is random). 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 4afc8d18a6..9393302355 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -380,8 +380,8 @@ class TestHiveContext private[hive]( """.stripMargin.cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd ), - // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING - // IS NOT YET SUPPORTED + // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC + // PARTITIONING IS NOT YET SUPPORTED TestTable("episodes_part", s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) |PARTITIONED BY (doctor_pt INT) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 4c1b425b16..e67fcbedc3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -482,7 +482,7 @@ abstract class HiveComparisonTest val tablesGenerated = queryList.zip(executions).flatMap { // We should take executedPlan instead of sparkPlan, because in following codes we // will run the collected plans. As we will do extra processing for sparkPlan such - // as adding exchage, collapsing codegen stages, etc., collecing sparkPlan here + // as adding exchange, collapsing codegen stages, etc., collecting sparkPlan here // will cause some errors when running these plans later. case (q, e) => e.executedPlan.collect { case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index b6fc61d453..eac65d5720 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -311,7 +311,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { case ExecutedCommand(_: InsertIntoHadoopFsRelation) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[HadoopFsRelation ].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan. " + s"However, found a ${o.toString} ") } @@ -341,7 +341,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { case ExecutedCommand(_: InsertIntoHadoopFsRelation) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[HadoopFsRelation ].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan." + s"However, found a ${o.toString} ") } -- cgit v1.2.3 From 76f3c735aa300d7ea6b17e64cc22d7e8fc3a8322 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sun, 3 Apr 2016 21:08:54 -0700 Subject: [SPARK-14356] Update spark.sql.execution.debug to work on Datasets ## What changes were proposed in this pull request? 
Update DebugQuery to work on Datasets of any type, not just DataFrames. ## How was this patch tested? Added unit tests, checked in spark-shell. Author: Matei Zaharia Closes #12140 from mateiz/debug-dataset. --- .../main/scala/org/apache/spark/sql/execution/debug/package.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 7b0c8ebdfa..17eae88b49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -77,9 +77,9 @@ package object debug { } /** - * Augments [[DataFrame]]s with debug methods. + * Augments [[Dataset]]s with debug methods. */ - implicit class DebugQuery(query: DataFrame) extends Logging { + implicit class DebugQuery(query: Dataset[_]) extends Logging { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index c0fce4b96a..8aa0114d98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.debug import org.apache.spark.SparkFunSuite import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData.TestData class DebuggingSuite extends SparkFunSuite with SharedSQLContext { @@ -26,6 +27,11 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { testData.debug() } + test("Dataset.debug()") { + import testImplicits._ + testData.as[TestData].debug() + } + test("debugCodegen") { val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) -- cgit v1.2.3 From 0340b3d279de6be4903673bbf3e6a1a2653de6c0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 4 Apr 2016 09:58:01 +0200 Subject: [SPARK-14360][SQL] QueryExecution.debug.codegen() to dump codegen ## What changes were proposed in this pull request? We recently added the ability to dump the generated code for a given query. However, the method is only available through an implicit after an import. It'd slightly simplify things if it can be called directly in queryExecution. ## How was this patch tested? Manually tested in spark-shell. Author: Reynold Xin Closes #12144 from rxin/SPARK-14360. --- .../org/apache/spark/sql/execution/QueryExecution.scala | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 4843553211..63eb1aa24e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -103,4 +103,20 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { |${stringOrError(executedPlan)} """.stripMargin.trim } + + /** A special namespace for commands that can be used to debug query execution. 
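For reference, a minimal usage sketch of the two debug entry points touched here (sketch only, assuming a SQLContext named sqlContext is in scope):

  import org.apache.spark.sql.execution.debug._   // brings the Dataset debug() implicit into scope

  val ds = sqlContext.range(10).groupBy("id").count()
  ds.debug()                          // runs the query with per-operator debugging output; now works on any Dataset, not just DataFrames
  ds.queryExecution.debug.codegen()   // prints the code generated for each WholeStageCodegen subtree, no extra import needed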
*/ + // scalastyle:off + object debug { + // scalastyle:on + + /** + * Prints to stdout all the generated code found in this plan (i.e. the output of each + * WholeStageCodegen subtree). + */ + def codegen(): Unit = { + // scalastyle:off println + println(org.apache.spark.sql.execution.debug.codegenString(executedPlan)) + // scalastyle:on println + } + } } -- cgit v1.2.3 From 745425332f41e2ae94649f9d1ad675243f36f743 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Apr 2016 10:01:24 -0700 Subject: [SPARK-14137] [SQL] Cleanup hash join ## What changes were proposed in this pull request? This PR did a few cleanups on HashedRelation and HashJoin: 1) Merge HashedRelation and UniqueHashedRelation together 2) Return an iterator from HashedRelation, so we do not need to create many UnsafeRow objects. 3) Return a copy of HashedRelation for thread-safety in BroadcastJoin, so we can re-use the UnsafeRow objects. 4) Cleanup HashJoin, share most of the code between BroadcastHashJoin and ShuffleHashJoin 5) Removed UniqueLongHashedRelation, which will be replaced by LongUnsafeMap (another PR). 6) Update benchmark; before this patch, the selectivity of joins was too high. ## How was this patch tested? Existing tests. Author: Davies Liu Closes #12102 from davies/cleanup_hash. --- .../sql/execution/joins/BroadcastHashJoin.scala | 78 ++---- .../spark/sql/execution/joins/HashJoin.scala | 217 +++++++---------- .../spark/sql/execution/joins/HashedRelation.scala | 264 +++++++++------------ .../sql/execution/joins/ShuffledHashJoin.scala | 32 +-- .../sql/execution/BenchmarkWholeStageCodegen.scala | 64 +++-- .../sql/execution/joins/HashedRelationSuite.scala | 14 +- 6 files changed, 268 insertions(+), 401 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 41e566c27b..67ac9e94ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -68,37 +68,9 @@ case class BroadcastHashJoin( val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => - val joinedRow = new JoinedRow() - val hashTable = broadcastRelation.value - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashTable.getMemorySize) - val keyGenerator = streamSideKeyGenerator - val resultProj = createResultProjection - - joinType match { - case Inner => - hashJoin(streamedIter, hashTable, numOutputRows) - - case LeftOuter => - streamedIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashTable.get(rowKey), resultProj, numOutputRows) - } - - case RightOuter => - streamedIter.flatMap { currentRow => - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashTable.get(rowKey), joinedRow, resultProj, numOutputRows) - } - - case LeftSemi => - hashSemiJoin(streamedIter, hashTable, numOutputRows) - - case x => - throw new IllegalArgumentException( - s"BroadcastHashJoin should not take $x as the JoinType") - } + val hashed = broadcastRelation.value.asReadOnlyCopy() + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize) + join(streamedIter, hashed, numOutputRows) } } @@ -132,7 +104,7 @@ case class BroadcastHashJoin( val clsName =
broadcastRelation.value.getClass.getName ctx.addMutableState(clsName, relationTerm, s""" - | $relationTerm = ($clsName) $broadcast.value(); + | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); | incPeakExecutionMemory($relationTerm.getMemorySize()); """.stripMargin) (broadcastRelation, relationTerm) @@ -217,7 +189,7 @@ case class BroadcastHashJoin( case BuildLeft => buildVars ++ input case BuildRight => input ++ buildVars } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side |${keyEv.code} @@ -232,18 +204,15 @@ case class BroadcastHashJoin( } else { ctx.copyResult = true val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName s""" |// generate join key for stream side |${keyEv.code} |// find matches from HashRelation - |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); |if ($matches == null) continue; - |int $size = $matches.size(); - |for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + |while ($matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); | $checkCondition | $numOutput.add(1); | ${consume(ctx, resultVars)} @@ -287,7 +256,7 @@ case class BroadcastHashJoin( case BuildLeft => buildVars ++ input case BuildRight => input ++ buildVars } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side |${keyEv.code} @@ -306,22 +275,21 @@ case class BroadcastHashJoin( } else { ctx.copyResult = true val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName + val iteratorCls = classOf[Iterator[UnsafeRow]].getName val i = ctx.freshName("i") - val size = ctx.freshName("size") val found = ctx.freshName("found") s""" |// generate join key for stream side |${keyEv.code} |// find matches from HashRelation - |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); - |int $size = $matches != null ? $matches.size() : 0; + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); |boolean $found = false; |// the last iteration of this loop is to emit an empty row if there is no matched rows. - |for (int $i = 0; $i <= $size; $i++) { - | UnsafeRow $matched = $i < $size ? (UnsafeRow) $matches.apply($i) : null; + |while ($matches != null && $matches.hasNext() || !$found) { + | UnsafeRow $matched = $matches != null && $matches.hasNext() ? 
+ | (UnsafeRow) $matches.next() : null; | ${checkCondition.trim} - | if (!$conditionPassed || ($i == $size && $found)) continue; + | if (!$conditionPassed) continue; | $found = true; | $numOutput.add(1); | ${consume(ctx, resultVars)} @@ -356,7 +324,7 @@ case class BroadcastHashJoin( "" } - if (broadcastRelation.value.isInstanceOf[UniqueHashedRelation]) { + if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side |${keyEv.code} @@ -369,23 +337,19 @@ case class BroadcastHashJoin( """.stripMargin } else { val matches = ctx.freshName("matches") - val bufferType = classOf[CompactBuffer[UnsafeRow]].getName - val i = ctx.freshName("i") - val size = ctx.freshName("size") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName val found = ctx.freshName("found") s""" |// generate join key for stream side |${keyEv.code} |// find matches from HashRelation - |$bufferType $matches = $anyNull ? null : ($bufferType)$relationTerm.get(${keyEv.value}); + |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); |if ($matches == null) continue; - |int $size = $matches.size(); |boolean $found = false; - |for (int $i = 0; $i < $size; $i++) { - | UnsafeRow $matched = (UnsafeRow) $matches.apply($i); + |while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); | $checkCondition | $found = true; - | break; |} |if (!$found) continue; |$numOutput.add(1); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index c298b7dee0..b7c0f3e7d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -19,10 +19,11 @@ package org.apache.spark.sql.execution.joins import java.util.NoSuchElementException +import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType} import org.apache.spark.util.collection.CompactBuffer @@ -110,169 +111,113 @@ trait HashJoin { sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] } - protected def buildSideKeyGenerator: Projection = + protected def buildSideKeyGenerator(): Projection = UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) - protected def streamSideKeyGenerator: Projection = + protected def streamSideKeyGenerator(): UnsafeProjection = UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) @transient private[this] lazy val boundCondition = if (condition.isDefined) { - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output) + newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) } else { (r: InternalRow) => true } - protected def createResultProjection: (InternalRow) => InternalRow = - UnsafeProjection.create(self.schema) - - protected def hashJoin( - streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - new Iterator[InternalRow] { - private[this] var currentStreamedRow: InternalRow = _ - private[this] var currentHashMatches: 
Seq[InternalRow] = _ - private[this] var currentMatchPosition: Int = -1 - - // Mutable per row objects. - private[this] val joinRow = new JoinedRow - private[this] val resultProjection = createResultProjection - - private[this] val joinKeys = streamSideKeyGenerator - - override final def hasNext: Boolean = { - while (true) { - // check if it's end of current matches - if (currentHashMatches != null && currentMatchPosition == currentHashMatches.length) { - currentHashMatches = null - currentMatchPosition = -1 - } - - // find the next match - while (currentHashMatches == null && streamIter.hasNext) { - currentStreamedRow = streamIter.next() - val key = joinKeys(currentStreamedRow) - if (!key.anyNull) { - currentHashMatches = hashedRelation.get(key) - if (currentHashMatches != null) { - currentMatchPosition = 0 - } - } - } - if (currentHashMatches == null) { - return false - } - - // found some matches - buildSide match { - case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition)) - case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow) - } - if (boundCondition(joinRow)) { - return true - } else { - currentMatchPosition += 1 - } - } - false // unreachable - } - - override final def next(): InternalRow = { - // next() could be called without calling hasNext() - if (hasNext) { - currentMatchPosition += 1 - numOutputRows += 1 - resultProjection(joinRow) - } else { - throw new NoSuchElementException - } - } + protected def createResultProjection(): (InternalRow) => InternalRow = { + if (joinType == LeftSemi) { + UnsafeProjection.create(output, output) + } else { + // Always put the stream side on left to simplify implementation + // both of left and right side could be null + UnsafeProjection.create( + output, (streamedPlan.output ++ buildPlan.output).map(_.withNullability(true))) } } - @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() - - @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) - - protected[this] def leftOuterIterator( - key: InternalRow, - joinedRow: JoinedRow, - rightIter: Iterable[InternalRow], - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if (!key.anyNull) { - val temp = if (rightIter != null) { - rightIter.collect { - case r if boundCondition(joinedRow.withRight(r)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() - } - } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil - } else { - temp - } + private def innerJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinRow = new JoinedRow + val joinKeys = streamSideKeyGenerator() + streamIter.flatMap { srow => + joinRow.withLeft(srow) + val matches = hashedRelation.get(joinKeys(srow)) + if (matches != null) { + matches.map(joinRow.withRight(_)).filter(boundCondition) } else { - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) :: Nil + Seq.empty } } - ret.iterator } - protected[this] def rightOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val ret: Iterable[InternalRow] = { - if 
(!key.anyNull) { - val temp = if (leftIter != null) { - leftIter.collect { - case l if boundCondition(joinedRow.withLeft(l)) => { - numOutputRows += 1 - resultProjection(joinedRow).copy() + private def outerJoin( + streamedIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinedRow = new JoinedRow() + val keyGenerator = streamSideKeyGenerator() + val nullRow = new GenericInternalRow(buildPlan.output.length) + + streamedIter.flatMap { currentRow => + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + val buildIter = hashedRelation.get(rowKey) + new RowIterator { + private var found = false + override def advanceNext(): Boolean = { + while (buildIter != null && buildIter.hasNext) { + val nextBuildRow = buildIter.next() + if (boundCondition(joinedRow.withRight(nextBuildRow))) { + found = true + return true } } - } else { - List.empty - } - if (temp.isEmpty) { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } else { - temp + if (!found) { + joinedRow.withRight(nullRow) + found = true + return true + } + false } - } else { - numOutputRows += 1 - resultProjection(joinedRow.withLeft(leftNullRow)) :: Nil - } + override def getRow: InternalRow = joinedRow + }.toScala } - ret.iterator } - protected def hashSemiJoin( - streamIter: Iterator[InternalRow], - hashedRelation: HashedRelation, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - val joinKeys = streamSideKeyGenerator + private def semiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() val joinedRow = new JoinedRow streamIter.filter { current => val key = joinKeys(current) - lazy val rowBuffer = hashedRelation.get(key) - val r = !key.anyNull && rowBuffer != null && (condition.isEmpty || rowBuffer.exists { + lazy val buildIter = hashedRelation.get(key) + !key.anyNull && buildIter != null && (condition.isEmpty || buildIter.exists { (row: InternalRow) => boundCondition(joinedRow(current, row)) }) - if (r) numOutputRows += 1 - r + } + } + + protected def join( + streamedIter: Iterator[InternalRow], + hashed: HashedRelation, + numOutputRows: LongSQLMetric): Iterator[InternalRow] = { + + val joinedIter = joinType match { + case Inner => + innerJoin(streamedIter, hashed) + case LeftOuter | RightOuter => + outerJoin(streamedIter, hashed) + case LeftSemi => + semiJoin(streamedIter, hashed) + case x => + throw new IllegalArgumentException( + s"BroadcastHashJoin should not take $x as the JoinType") + } + + val resultProj = createResultProjection + joinedIter.map { r => + numOutputRows += 1 + resultProj(r) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 91c470d187..5ccb435686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.unsafe.Platform -import org.apache.spark.unsafe.hash.Murmur3_x86_32 import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, Utils} import org.apache.spark.util.collection.CompactBuffer @@ -39,16 +38,42 @@ 
import org.apache.spark.util.collection.CompactBuffer private[execution] sealed trait HashedRelation { /** * Returns matched rows. + * + * Returns null if there is no matched rows. */ - def get(key: InternalRow): Seq[InternalRow] + def get(key: InternalRow): Iterator[InternalRow] /** * Returns matched rows for a key that has only one column with LongType. + * + * Returns null if there is no matched rows. + */ + def get(key: Long): Iterator[InternalRow] = { + throw new UnsupportedOperationException + } + + /** + * Returns the matched single row. */ - def get(key: Long): Seq[InternalRow] = { + def getValue(key: InternalRow): InternalRow + + /** + * Returns the matched single row with key that have only one column of LongType. + */ + def getValue(key: Long): InternalRow = { throw new UnsupportedOperationException } + /** + * Returns true iff all the keys are unique. + */ + def keyIsUnique: Boolean + + /** + * Returns a read-only copy of this, to be safely used in current thread. + */ + def asReadOnlyCopy(): HashedRelation + /** * Returns the size of used memory. */ @@ -76,44 +101,6 @@ private[execution] sealed trait HashedRelation { } } -/** - * Interface for a hashed relation that have only one row per key. - * - * We should call getValue() for better performance. - */ -private[execution] trait UniqueHashedRelation extends HashedRelation { - - /** - * Returns the matched single row. - */ - def getValue(key: InternalRow): InternalRow - - /** - * Returns the matched single row with key that have only one column of LongType. - */ - def getValue(key: Long): InternalRow = { - throw new UnsupportedOperationException - } - - override def get(key: InternalRow): Seq[InternalRow] = { - val row = getValue(key) - if (row != null) { - CompactBuffer[InternalRow](row) - } else { - null - } - } - - override def get(key: Long): Seq[InternalRow] = { - val row = getValue(key) - if (row != null) { - CompactBuffer[InternalRow](row) - } else { - null - } - } -} - private[execution] object HashedRelation { /** @@ -150,6 +137,11 @@ private[joins] class UnsafeHashedRelation( private[joins] def this() = this(0, null) // Needed for serialization + override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues() + + override def asReadOnlyCopy(): UnsafeHashedRelation = + new UnsafeHashedRelation(numFields, binaryMap) + override def getMemorySize: Long = { binaryMap.getTotalMemoryConsumption } @@ -158,23 +150,39 @@ private[joins] class UnsafeHashedRelation( binaryMap.getTotalMemoryConsumption } - override def get(key: InternalRow): Seq[InternalRow] = { + // re-used in get()/getValue() + var resultRow = new UnsafeRow(numFields) + + override def get(key: InternalRow): Iterator[InternalRow] = { val unsafeKey = key.asInstanceOf[UnsafeRow] val map = binaryMap // avoid the compiler error val loc = new map.Location // this could be allocated in stack binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) if (loc.isDefined) { - val buffer = CompactBuffer[UnsafeRow]() - val row = new UnsafeRow(numFields) - row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - buffer += row - while (loc.nextValue()) { - val row = new UnsafeRow(numFields) - row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - buffer += row + new Iterator[UnsafeRow] { + private var _hasNext = true + override def hasNext: Boolean = _hasNext + override def next(): UnsafeRow = { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, 
loc.getValueLength) + _hasNext = loc.nextValue() + resultRow + } } - buffer + } else { + null + } + } + + def getValue(key: InternalRow): InternalRow = { + val unsafeKey = key.asInstanceOf[UnsafeRow] + val map = binaryMap // avoid the compiler error + val loc = new map.Location // this could be allocated in stack + binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, + unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) + if (loc.isDefined) { + resultRow.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + resultRow } else { null } @@ -212,6 +220,7 @@ private[joins] class UnsafeHashedRelation( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { numFields = in.readInt() + resultRow = new UnsafeRow(numFields) val nKeys = in.readInt() val nValues = in.readInt() // This is used in Broadcast, shared by multiple tasks, so we use on-heap memory @@ -263,29 +272,6 @@ private[joins] class UnsafeHashedRelation( } } -/** - * A HashedRelation for UnsafeRow with unique keys. - */ -private[joins] final class UniqueUnsafeHashedRelation( - private var numFields: Int, - private var binaryMap: BytesToBytesMap) - extends UnsafeHashedRelation(numFields, binaryMap) with UniqueHashedRelation { - def getValue(key: InternalRow): InternalRow = { - val unsafeKey = key.asInstanceOf[UnsafeRow] - val map = binaryMap // avoid the compiler error - val loc = new map.Location // this could be allocated in stack - binaryMap.safeLookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset, - unsafeKey.getSizeInBytes, loc, unsafeKey.hashCode()) - if (loc.isDefined) { - val row = new UnsafeRow(numFields) - row.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - row - } else { - null - } - } -} - private[joins] object UnsafeHashedRelation { def apply( @@ -315,17 +301,12 @@ private[joins] object UnsafeHashedRelation { // Create a mapping of buildKeys -> rows var numFields = 0 - // Whether all the keys are unique or not - var allUnique: Boolean = true while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] numFields = row.numFields() val key = keyGenerator(row) if (!key.anyNull) { val loc = binaryMap.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) - if (loc.isDefined) { - allUnique = false - } val success = loc.append( key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, row.getBaseObject, row.getBaseOffset, row.getSizeInBytes) @@ -336,11 +317,7 @@ private[joins] object UnsafeHashedRelation { } } - if (allUnique) { - new UniqueUnsafeHashedRelation(numFields, binaryMap) - } else { - new UnsafeHashedRelation(numFields, binaryMap) - } + new UnsafeHashedRelation(numFields, binaryMap) } } @@ -348,9 +325,12 @@ private[joins] object UnsafeHashedRelation { * An interface for a hashed relation that the key is a Long. 
*/ private[joins] trait LongHashedRelation extends HashedRelation { - override def get(key: InternalRow): Seq[InternalRow] = { + override def get(key: InternalRow): Iterator[InternalRow] = { get(key.getLong(0)) } + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) + } } private[joins] final class GeneralLongHashedRelation( @@ -360,30 +340,18 @@ private[joins] final class GeneralLongHashedRelation( // Needed for serialization (it is public to make Java serialization work) def this() = this(null) - override def get(key: Long): Seq[InternalRow] = hashTable.get(key) + override def keyIsUnique: Boolean = false - override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) - } - - override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) - } -} + override def asReadOnlyCopy(): GeneralLongHashedRelation = + new GeneralLongHashedRelation(hashTable) -private[joins] final class UniqueLongHashedRelation( - private var hashTable: JavaHashMap[Long, UnsafeRow]) - extends UniqueHashedRelation with LongHashedRelation with Externalizable { - - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) - - override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) - } - - override def getValue(key: Long): InternalRow = { - hashTable.get(key) + override def get(key: Long): Iterator[InternalRow] = { + val rows = hashTable.get(key) + if (rows != null) { + rows.toIterator + } else { + null + } } override def writeExternal(out: ObjectOutput): Unit = { @@ -422,25 +390,36 @@ private[joins] final class LongArrayRelation( private var offsets: Array[Int], private var sizes: Array[Int], private var bytes: Array[Byte] - ) extends UniqueHashedRelation with LongHashedRelation with Externalizable { + ) extends LongHashedRelation with Externalizable { // Needed for serialization (it is public to make Java serialization work) def this() = this(0, 0L, null, null, null) - override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) + override def keyIsUnique: Boolean = true + + override def asReadOnlyCopy(): LongArrayRelation = { + new LongArrayRelation(numFields, start, offsets, sizes, bytes) } override def getMemorySize: Long = { offsets.length * 4 + sizes.length * 4 + bytes.length } + override def get(key: Long): Iterator[InternalRow] = { + val row = getValue(key) + if (row != null) { + Seq(row).toIterator + } else { + null + } + } + + var resultRow = new UnsafeRow(numFields) override def getValue(key: Long): InternalRow = { val idx = (key - start).toInt if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { - val result = new UnsafeRow(numFields) - result.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) - result + resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) + resultRow } else { null } @@ -461,6 +440,7 @@ private[joins] final class LongArrayRelation( override def readExternal(in: ObjectInput): Unit = { numFields = in.readInt() + resultRow = new UnsafeRow(numFields) start = in.readLong() val length = in.readInt() // read sizes of rows @@ -523,44 +503,32 @@ private[joins] object LongHashedRelation { } } - if (keyIsUnique) { - if (hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { - // The keys are dense enough, so use LongArrayRelation - val length = (maxKey - minKey).toInt + 1 - val sizes = new Array[Int](length) - val offsets = new 
Array[Int](length) - var offset = 0 - var i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - offsets(i) = offset - sizes(i) = rows(0).getSizeInBytes - offset += sizes(i) - } - i += 1 - } - val bytes = new Array[Byte](offset) - i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) - } - i += 1 + if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { + // The keys are dense enough, so use LongArrayRelation + val length = (maxKey - minKey).toInt + 1 + val sizes = new Array[Int](length) + val offsets = new Array[Int](length) + var offset = 0 + var i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + offsets(i) = offset + sizes(i) = rows(0).getSizeInBytes + offset += sizes(i) } - new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) - - } else { - // all the keys are unique, one row per key. - val uniqHashTable = new JavaHashMap[Long, UnsafeRow](hashTable.size) - val iter = hashTable.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - uniqHashTable.put(entry.getKey, entry.getValue()(0)) + i += 1 + } + val bytes = new Array[Byte](offset) + i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) } - new UniqueLongHashedRelation(uniqHashTable) + i += 1 } + new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) } else { new GeneralLongHashedRelation(hashTable) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index e3a2eaea5d..c63faacf33 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -102,39 +102,9 @@ case class ShuffledHashJoin( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") - streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) - val joinedRow = new JoinedRow - joinType match { - case Inner => - hashJoin(streamIter, hashed, numOutputRows) - - case LeftSemi => - hashSemiJoin(streamIter, hashed, numOutputRows) - - case LeftOuter => - val keyGenerator = streamSideKeyGenerator - val resultProj = createResultProjection - streamIter.flatMap(currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withLeft(currentRow) - leftOuterIterator(rowKey, joinedRow, hashed.get(rowKey), resultProj, numOutputRows) - }) - - case RightOuter => - val keyGenerator = streamSideKeyGenerator - val resultProj = createResultProjection - streamIter.flatMap(currentRow => { - val rowKey = keyGenerator(currentRow) - joinedRow.withRight(currentRow) - rightOuterIterator(rowKey, hashed.get(rowKey), joinedRow, resultProj, numOutputRows) - }) - - case x => - throw new IllegalArgumentException( - s"ShuffledHashJoin should not take $x as the JoinType") - } + join(streamIter, hashed, numOutputRows) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 289e1b6db9..3566ef3043 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -166,20 +166,35 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("broadcast hash join") { - val N = 100 << 20 + val N = 20 << 20 val M = 1 << 16 val dim = broadcast(sqlContext.range(M).selectExpr("id as k", "cast(id as string) as v")) runBenchmark("Join w long", N) { - sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k")).count() + sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count() } /* + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 5744 / 5814 18.3 54.8 1.0X - Join w long codegen=true 735 / 853 142.7 7.0 7.8X + Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X + Join w long codegen=true 275 / 352 76.2 13.1 19.4X + */ + + runBenchmark("Join w long duplicated", N) { + val dim = broadcast(sqlContext.range(M).selectExpr("cast(id/10 as long) as k")) + sqlContext.range(N).join(dim, (col("id") % M) === col("k")).count() + } + + /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X + Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X */ val dim2 = broadcast(sqlContext.range(M) @@ -187,16 +202,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { runBenchmark("Join w 2 ints", N) { sqlContext.range(N).join(dim2, - (col("id") bitwiseAND M).cast(IntegerType) === col("k1") - && (col("id") bitwiseAND M).cast(IntegerType) === col("k2")).count() + (col("id") % M).cast(IntegerType) === col("k1") + && (col("id") % M).cast(IntegerType) === col("k2")).count() } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 7159 / 7224 14.6 68.3 1.0X - Join w 2 ints codegen=true 1135 / 1197 92.4 10.8 6.3X + Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X + Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X */ val dim3 = broadcast(sqlContext.range(M) @@ -204,16 +220,17 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { runBenchmark("Join w 2 longs", N) { sqlContext.range(N).join(dim3, - (col("id") bitwiseAND M) === col("k1") && (col("id") bitwiseAND M) === col("k2")) + (col("id") % M) === col("k1") && (col("id") % M) === col("k2")) .count() } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 longs codegen=false 12725 / 13158 8.2 121.4 1.0X - Join w 2 longs codegen=true 6044 / 6771 17.3 57.6 2.1X + Join w 2 longs codegen=false 5905 / 6123 3.6 281.6 1.0X + Join w 2 longs codegen=true 2230 / 2529 9.4 106.3 2.6X */ val dim4 = 
broadcast(sqlContext.range(M) @@ -227,34 +244,36 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { /** Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Join w 2 longs: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + Join w 2 longs duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 longs duplicated codegen=false 13066 / 13710 8.0 124.6 1.0X - Join w 2 longs duplicated codegen=true 7122 / 7277 14.7 67.9 1.8X + Join w 2 longs duplicated codegen=false 6420 / 6587 3.3 306.1 1.0X + Join w 2 longs duplicated codegen=true 2080 / 2139 10.1 99.2 3.1X */ runBenchmark("outer join w long", N) { - sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "left").count() + sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "left").count() } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 15280 / 16497 6.9 145.7 1.0X - outer join w long codegen=true 769 / 796 136.3 7.3 19.9X + outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X + outer join w long codegen=true 216 / 226 97.2 10.3 26.3X */ runBenchmark("semi join w long", N) { - sqlContext.range(N).join(dim, (col("id") bitwiseAND M) === col("k"), "leftsemi").count() + sqlContext.range(N).join(dim, (col("id") % M) === col("k"), "leftsemi").count() } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - semi join w long codegen=false 5804 / 5969 18.1 55.3 1.0X - semi join w long codegen=true 814 / 934 128.8 7.8 7.1X + semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X + semi join w long codegen=true 211 / 229 99.2 10.1 22.2X */ } @@ -303,11 +322,12 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } /** + Java HotSpot(TM) 64-Bit Server VM 1.7.0_60-b19 on Mac OS X 10.9.5 Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1168 / 1902 3.6 278.6 1.0X - shuffle hash join codegen=true 850 / 1196 4.9 202.8 1.4X + shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X + shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index ed4cc1c4c4..ed87a99439 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -34,20 +34,20 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) - val unsafeData = data.map(toUnsafe(_).copy()).toArray + val unsafeData = data.map(toUnsafe(_).copy()) val buildKey = 
Seq(BoundReference(0, IntegerType, false)) val keyGenerator = UnsafeProjection.create(buildKey) val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) - assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) - assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) + assert(hashed.get(unsafeData(1)).toArray === Array(unsafeData(1))) assert(hashed.get(toUnsafe(InternalRow(10))) === null) val data2 = CompactBuffer[InternalRow](unsafeData(2).copy()) data2 += unsafeData(2).copy() - assert(hashed.get(unsafeData(2)) === data2) + assert(hashed.get(unsafeData(2)).toArray === data2.toArray) val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) @@ -56,10 +56,10 @@ val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) val hashed2 = new UnsafeHashedRelation() hashed2.readExternal(in) - assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0))) - assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1))) + assert(hashed2.get(unsafeData(0)).toArray === Array(unsafeData(0))) + assert(hashed2.get(unsafeData(1)).toArray === Array(unsafeData(1))) assert(hashed2.get(toUnsafe(InternalRow(10))) === null) - assert(hashed2.get(unsafeData(2)) === data2) + assert(hashed2.get(unsafeData(2)).toArray === data2) val os2 = new ByteArrayOutputStream() val out2 = new ObjectOutputStream(os2)
-- cgit v1.2.3


From 89f3befab6c150f87de2fb91b50ea8b414c69095 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Mon, 4 Apr 2016 10:24:02 -0700
Subject: [SPARK-13784][ML] Persistence for RandomForestClassifier, RandomForestRegressor

## What changes were proposed in this pull request?

**Main change**: Added save/load for RandomForestClassifier, RandomForestRegressor (implementation details below)

Modified numTrees method (*deprecation*)
* Goal: Use default implementations of unit tests which assume Estimators and Models share the same set of Params.
* What this PR does: Moves method numTrees outside of trait TreeEnsembleModel. Adds it to GBT and RF Models. Deprecates it in RF Models in favor of new method getNumTrees. In Spark 2.1, we can have RF Models include Param numTrees.

Minor items
* Fixes bugs in GBTClassificationModel, GBTRegressionModel fromOld methods where they assign the wrong old UID.

**Implementation details**
* Split DecisionTreeModelReadWrite.loadTreeNodes into 2 methods in order to reuse some code for ensembles.
* Added EnsembleModelReadWrite object with save/load implementations usable for RFs and GBTs
  * These store all trees' nodes in a single DataFrame, and all trees' metadata in a second DataFrame.
* Split trait RandomForestParams into parts in order to add more Estimator Params to RF models
* Split DefaultParamsWriter.saveMetadata into two methods to allow ensembles to store sub-models' metadata in a single DataFrame. Same for DefaultParamsReader.loadMetadata

## How was this patch tested?

Adds standard unit tests for RF save/load

Author: Joseph K. Bradley
Author: GayathriMurali

Closes #12118 from jkbradley/GayathriMurali-SPARK-13784.
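For context, a minimal usage sketch of the persistence API this change enables (illustrative only, not part of the change set; the `training` DataFrame and the output path below are hypothetical):

```scala
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}

// Fit a small forest; `training` is assumed to be a DataFrame with "label" and "features" columns.
val rf = new RandomForestClassifier()
  .setNumTrees(10)
  .setMaxDepth(5)
val model = rf.fit(training)

// Save: writes the ensemble's own metadata, one DataFrame with all trees' metadata,
// and one DataFrame with all trees' nodes (see EnsembleModelReadWrite.saveImpl).
model.write.overwrite().save("/tmp/rfcModel")

// Load: rebuilds each tree from its stored nodes, restores its Params, and
// validates the number of trees against the saved metadata.
val restored = RandomForestClassificationModel.load("/tmp/rfcModel")
assert(restored.getNumTrees == model.getNumTrees)
```

The same write/load pattern applies to RandomForestRegressor and RandomForestRegressionModel.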
--- .../spark/ml/classification/GBTClassifier.scala | 7 +- .../ml/classification/RandomForestClassifier.scala | 100 ++++++++++++++-- .../apache/spark/ml/regression/GBTRegressor.scala | 7 +- .../ml/regression/RandomForestRegressor.scala | 96 +++++++++++++-- .../org/apache/spark/ml/tree/treeModels.scala | 131 +++++++++++++++++++-- .../org/apache/spark/ml/tree/treeParams.scala | 66 +++++++---- .../scala/org/apache/spark/ml/util/ReadWrite.scala | 42 ++++++- .../RandomForestClassifierSuite.scala | 40 +++---- .../ml/regression/RandomForestRegressorSuite.scala | 38 +++--- 9 files changed, 424 insertions(+), 103 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 48ce051d0a..bfefaf1a1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -192,7 +192,7 @@ final class GBTClassificationModel private[ml]( extends PredictionModel[Vector, GBTClassificationModel] with TreeEnsembleModel with Serializable { - require(numTrees > 0, "GBTClassificationModel requires at least 1 tree.") + require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") @@ -227,6 +227,9 @@ final class GBTClassificationModel private[ml]( if (prediction > 0.0) 1.0 else 0.0 } + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): GBTClassificationModel = { copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures), @@ -272,6 +275,6 @@ private[ml] object GBTClassificationModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc") - new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) + new GBTClassificationModel(uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 82fa05a604..2ad893f4fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -17,11 +17,15 @@ package org.apache.spark.ml.classification +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -43,7 +47,7 @@ import org.apache.spark.sql.functions._ final class RandomForestClassifier @Since("1.4.0") ( @Since("1.4.0") 
override val uid: String) extends ProbabilisticClassifier[Vector, RandomForestClassifier, RandomForestClassificationModel] - with RandomForestParams with TreeClassifierParams { + with RandomForestClassifierParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("rfc")) @@ -120,7 +124,7 @@ final class RandomForestClassifier @Since("1.4.0") ( @Since("1.4.0") @Experimental -object RandomForestClassifier { +object RandomForestClassifier extends DefaultParamsReadable[RandomForestClassifier] { /** Accessor for supported impurity settings: entropy, gini */ @Since("1.4.0") final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities @@ -129,6 +133,9 @@ object RandomForestClassifier { @Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies + + @Since("2.0.0") + override def load(path: String): RandomForestClassifier = super.load(path) } /** @@ -136,8 +143,9 @@ object RandomForestClassifier { * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for classification. * It supports both binary and multiclass labels, as well as both continuous and categorical * features. + * * @param _trees Decision trees in the ensemble. - * Warning: These have null parents. + * Warning: These have null parents. */ @Since("1.4.0") @Experimental @@ -147,12 +155,14 @@ final class RandomForestClassificationModel private[ml] ( @Since("1.6.0") override val numFeatures: Int, @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] - with TreeEnsembleModel with Serializable { + with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable + with Serializable { - require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.") + require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.") /** * Construct a random forest classification model, with all trees weighted equally. + * * @param trees Component trees */ private[ml] def this( @@ -165,7 +175,7 @@ final class RandomForestClassificationModel private[ml] ( override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. - private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights @@ -208,6 +218,15 @@ final class RandomForestClassificationModel private[ml] ( } } + /** + * Number of trees in ensemble + * + * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0 + */ + // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams + @deprecated("Use getNumTrees instead. 
This method will be removed in 2.1.0.", "2.0.0") + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) @@ -216,7 +235,7 @@ final class RandomForestClassificationModel private[ml] ( @Since("1.4.0") override def toString: String = { - s"RandomForestClassificationModel (uid=$uid) with $numTrees trees" + s"RandomForestClassificationModel (uid=$uid) with $getNumTrees trees" } /** @@ -236,12 +255,69 @@ final class RandomForestClassificationModel private[ml] ( private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Classification, _trees.map(_.toOld)) } + + @Since("2.0.0") + override def write: MLWriter = + new RandomForestClassificationModel.RandomForestClassificationModelWriter(this) } -private[ml] object RandomForestClassificationModel { +@Since("2.0.0") +object RandomForestClassificationModel extends MLReadable[RandomForestClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[RandomForestClassificationModel] = + new RandomForestClassificationModelReader + + @Since("2.0.0") + override def load(path: String): RandomForestClassificationModel = super.load(path) + + private[RandomForestClassificationModel] + class RandomForestClassificationModelWriter(instance: RandomForestClassificationModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + // Note: numTrees is not currently used, but could be nice to store for fast querying. + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numClasses" -> instance.numClasses, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class RandomForestClassificationModelReader + extends MLReader[RandomForestClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RandomForestClassificationModel].getName + private val treeClassName = classOf[DecisionTreeClassificationModel].getName + + override def load(path: String): RandomForestClassificationModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numClasses = (metadata.metadata \ "numClasses").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeClassificationModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeClassificationModel(treeMetadata.uid, root, numFeatures, numClasses) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"RandomForestClassificationModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new RandomForestClassificationModel(metadata.uid, trees, numFeatures, numClasses) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldRandomForestModel, parent: RandomForestClassifier, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 8fca35da51..02e124a1c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -179,7 +179,7 @@ final class GBTRegressionModel private[ml]( extends PredictionModel[Vector, GBTRegressionModel] with TreeEnsembleModel with Serializable { - require(numTrees > 0, "GBTRegressionModel requires at least 1 tree.") + require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).") @@ -213,6 +213,9 @@ final class GBTRegressionModel private[ml]( blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } + /** Number of trees in ensemble */ + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): GBTRegressionModel = { copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures), @@ -258,6 +261,6 @@ private[ml] object GBTRegressionModel { DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr") - new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures) + new GBTRegressionModel(uid, newTrees, oldModel.treeWeights, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 5b3f3a1f5d..ba56b5cd3f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -17,12 +17,16 @@ package org.apache.spark.ml.regression +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -41,7 +45,7 @@ import org.apache.spark.sql.functions._ @Experimental final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, RandomForestRegressor, RandomForestRegressionModel] - with RandomForestParams with TreeRegressorParams { + with RandomForestRegressorParams with DefaultParamsWritable { @Since("1.4.0") def this() = this(Identifiable.randomUID("rfr")) @@ -108,7 +112,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") @Experimental -object RandomForestRegressor { +object RandomForestRegressor extends DefaultParamsReadable[RandomForestRegressor]{ /** Accessor for supported impurity settings: variance */ @Since("1.4.0") final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities @@ -117,12 +121,17 @@ object RandomForestRegressor { 
@Since("1.4.0") final val supportedFeatureSubsetStrategies: Array[String] = RandomForestParams.supportedFeatureSubsetStrategies + + @Since("2.0.0") + override def load(path: String): RandomForestRegressor = super.load(path) + } /** * :: Experimental :: * [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] model for regression. * It supports both continuous and categorical features. + * * @param _trees Decision trees in the ensemble. * @param numFeatures Number of features used by this model */ @@ -133,12 +142,13 @@ final class RandomForestRegressionModel private[ml] ( private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] - with TreeEnsembleModel with Serializable { + with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable { - require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.") + require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") /** * Construct a random forest regression model, with all trees weighted equally. + * * @param trees Component trees */ private[ml] def this(trees: Array[DecisionTreeRegressionModel], numFeatures: Int) = @@ -148,7 +158,7 @@ final class RandomForestRegressionModel private[ml] ( override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] // Note: We may add support for weights (based on tree performance) later on. - private lazy val _treeWeights: Array[Double] = Array.fill[Double](numTrees)(1.0) + private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights @@ -165,9 +175,17 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. - _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees } + /** + * Number of trees in ensemble + * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0 + */ + // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams + @deprecated("Use getNumTrees instead. 
This method will be removed in 2.1.0.", "2.0.0") + val numTrees: Int = trees.length + @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressionModel = { copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) @@ -175,7 +193,7 @@ final class RandomForestRegressionModel private[ml] ( @Since("1.4.0") override def toString: String = { - s"RandomForestRegressionModel (uid=$uid) with $numTrees trees" + s"RandomForestRegressionModel (uid=$uid) with $getNumTrees trees" } /** @@ -195,12 +213,63 @@ final class RandomForestRegressionModel private[ml] ( private[ml] def toOld: OldRandomForestModel = { new OldRandomForestModel(OldAlgo.Regression, _trees.map(_.toOld)) } + + @Since("2.0.0") + override def write: MLWriter = + new RandomForestRegressionModel.RandomForestRegressionModelWriter(this) } -private[ml] object RandomForestRegressionModel { +@Since("2.0.0") +object RandomForestRegressionModel extends MLReadable[RandomForestRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[RandomForestRegressionModel] = new RandomForestRegressionModelReader + + @Since("2.0.0") + override def load(path: String): RandomForestRegressionModel = super.load(path) + + private[RandomForestRegressionModel] + class RandomForestRegressionModelWriter(instance: RandomForestRegressionModel) + extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class RandomForestRegressionModelReader extends MLReader[RandomForestRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[RandomForestRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): RandomForestRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"RandomForestRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new RandomForestRegressionModel(metadata.uid, trees, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldRandomForestModel, parent: RandomForestRegressor, categoricalFeatures: Map[Int, Int], @@ -211,6 +280,7 @@ private[ml] object RandomForestRegressionModel { // parent for each tree is null since there is no good way to set this. 
DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures) } - new RandomForestRegressionModel(parent.uid, newTrees, numFeatures) + val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfr") + new RandomForestRegressionModel(uid, newTrees, numFeatures) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 8ea767b2b3..48b8fd19ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -21,12 +21,15 @@ import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.ml.param.Param -import org.apache.spark.ml.util.DefaultParamsReader +import org.apache.spark.ml.param.{Param, Params} +import org.apache.spark.ml.tree.DecisionTreeModelReadWrite.NodeData +import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter} +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} -import org.apache.spark.sql.SQLContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Dataset, SQLContext} import org.apache.spark.util.collection.OpenHashMap /** @@ -88,6 +91,11 @@ private[ml] trait TreeEnsembleModel { /** Trees in this ensemble. Warning: These have null parent Estimators. */ def trees: Array[DecisionTreeModel] + /** + * Number of trees in ensemble + */ + val getNumTrees: Int = trees.length + /** Weights for each tree, zippable with [[trees]] */ def treeWeights: Array[Double] @@ -98,7 +106,7 @@ private[ml] trait TreeEnsembleModel { /** Summary of the model */ override def toString: String = { // Implementing classes should generally override this method to be more descriptive. - s"TreeEnsembleModel with $numTrees trees" + s"TreeEnsembleModel with ${trees.length} trees" } /** Full description of model */ @@ -109,9 +117,6 @@ private[ml] trait TreeEnsembleModel { }.fold("")(_ + _) } - /** Number of trees in ensemble */ - val numTrees: Int = trees.length - /** Total number of nodes, summed over all trees in the ensemble. */ lazy val totalNumNodes: Int = trees.map(_.numNodes).sum } @@ -316,6 +321,10 @@ private[ml] object DecisionTreeModelReadWrite { } } + /** + * Load a decision tree from a file. + * @return Root node of reconstructed tree + */ def loadTreeNodes( path: String, metadata: DefaultParamsReader.Metadata, @@ -331,9 +340,18 @@ private[ml] object DecisionTreeModelReadWrite { val dataPath = new Path(path, "data").toString val data = sqlContext.read.parquet(dataPath).as[NodeData] + buildTreeFromNodes(data.collect(), impurityType) + } + /** + * Given all data for all nodes in a tree, rebuild the tree. + * @param data Unsorted node data + * @param impurityType Impurity type for this tree + * @return Root node of reconstructed tree + */ + def buildTreeFromNodes(data: Array[NodeData], impurityType: String): Node = { // Load all nodes, sorted by ID. - val nodes: Array[NodeData] = data.collect().sortBy(_.id) + val nodes = data.sortBy(_.id) // Sanity checks; could remove assert(nodes.head.id == 0, s"Decision Tree load failed. 
Expected smallest node ID to be 0," + s" but found ${nodes.head.id}") @@ -358,3 +376,100 @@ private[ml] object DecisionTreeModelReadWrite { finalNodes.head } } + +private[ml] object EnsembleModelReadWrite { + + /** + * Helper method for saving a tree ensemble to disk. + * + * @param instance Tree ensemble model + * @param path Path to which to save the ensemble model. + * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees. + */ + def saveImpl[M <: Params with TreeEnsembleModel]( + instance: M, + path: String, + sql: SQLContext, + extraMetadata: JObject): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) + val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map { + case (tree, treeID) => + treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext) + } + val treesMetadataPath = new Path(path, "treesMetadata").toString + sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata") + .write.parquet(treesMetadataPath) + val dataPath = new Path(path, "data").toString + val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap { + case (tree, treeID) => EnsembleNodeData.build(tree, treeID) + } + sql.createDataFrame(nodeDataRDD).write.parquet(dataPath) + } + + /** + * Helper method for loading a tree ensemble from disk. + * This reconstructs all trees, returning the root nodes. + * @param path Path given to [[saveImpl()]] + * @param className Class name for ensemble model type + * @param treeClassName Class name for tree model type in the ensemble + * @return (ensemble metadata, array over trees of (tree metadata, root node)), + * where the root node is linked with all descendents + * @see [[saveImpl()]] for how the model was saved + */ + def loadImpl( + path: String, + sql: SQLContext, + className: String, + treeClassName: String): (Metadata, Array[(Metadata, Node)]) = { + import sql.implicits._ + implicit val format = DefaultFormats + val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) + + // Get impurity to construct ImpurityCalculator for each node + val impurityType: String = { + val impurityJson: JValue = metadata.getParamValue("impurity") + Param.jsonDecode[String](compact(render(impurityJson))) + } + + val treesMetadataPath = new Path(path, "treesMetadata").toString + val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath) + .select("treeID", "metadata").as[(Int, String)].rdd.map { + case (treeID: Int, json: String) => + treeID -> DefaultParamsReader.parseMetadata(json, treeClassName) + } + val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect() + + val dataPath = new Path(path, "data").toString + val nodeData: Dataset[EnsembleNodeData] = + sql.read.parquet(dataPath).as[EnsembleNodeData] + val rootNodesRDD: RDD[(Int, Node)] = + nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map { + case (treeID: Int, nodeData: Iterable[NodeData]) => + treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) + } + val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() + (metadata, treesMetadata.zip(rootNodes)) + } + + /** + * Info for one [[Node]] in a tree ensemble + * + * @param treeID Tree index + * @param nodeData Data for this node + */ + case class EnsembleNodeData( + treeID: Int, + nodeData: NodeData) + + object EnsembleNodeData { + /** + * Create [[EnsembleNodeData]] instances for the given tree. 
+ * + * @return Sequence of nodes for this tree + */ + def build(tree: DecisionTreeModel, treeID: Int): Seq[EnsembleNodeData] = { + val (nodeData: Seq[NodeData], _) = NodeData.build(tree.rootNode, 0) + nodeData.map(nd => EnsembleNodeData(treeID, nd)) + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 4fbd957677..78e6d3bfac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -315,22 +315,8 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { } } -/** - * Parameters for Random Forest algorithms. - * - * Note: Marked as private and DeveloperApi since this may be made public in the future. - */ -private[ml] trait RandomForestParams extends TreeEnsembleParams { - - /** - * Number of trees to train (>= 1). - * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). SPARK-7130 - * (default = 20) - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) +/** Used for [[RandomForestParams]] */ +private[ml] trait HasFeatureSubsetStrategy extends Params { /** * The number of features to consider for splits at each tree node. @@ -362,27 +348,65 @@ private[ml] trait RandomForestParams extends TreeEnsembleParams { (value: String) => RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) - setDefault(numTrees -> 20, featureSubsetStrategy -> "auto") + setDefault(featureSubsetStrategy -> "auto") /** @group setParam */ - def setNumTrees(value: Int): this.type = set(numTrees, value) + def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) /** @group getParam */ - final def getNumTrees: Int = $(numTrees) + final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase +} + +/** + * Used for [[RandomForestParams]]. + * This is separated out from [[RandomForestParams]] because of an issue with the + * `numTrees` method conflicting with this Param in the Estimator. + */ +private[ml] trait HasNumTrees extends Params { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + setDefault(numTrees -> 20) /** @group setParam */ - def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value) + def setNumTrees(value: Int): this.type = set(numTrees, value) /** @group getParam */ - final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase + final def getNumTrees: Int = $(numTrees) } +/** + * Parameters for Random Forest algorithms. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with HasNumTrees + private[spark] object RandomForestParams { // These options should be lowercase. 
final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) } +private[ml] trait RandomForestClassifierParams + extends RandomForestParams with TreeClassifierParams + +private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with TreeClassifierParams + +private[ml] trait RandomForestRegressorParams + extends RandomForestParams with TreeRegressorParams + +private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams + with HasFeatureSubsetStrategy with TreeRegressorParams + /** * Parameters for Gradient-Boosted Tree algorithms. * diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 39999ede30..7dec07ea14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -144,6 +144,7 @@ private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => /** * Abstract class for utility classes that can load ML instances. + * * @tparam T ML instance type */ @Experimental @@ -162,6 +163,7 @@ abstract class MLReader[T] extends BaseReadWrite { /** * Trait for objects that provide [[MLReader]]. + * * @tparam T ML instance type */ @Experimental @@ -192,6 +194,7 @@ private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { * Default [[MLWriter]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). + * * @param instance object to save */ private[ml] class DefaultParamsWriter(instance: Params) extends MLWriter { @@ -211,6 +214,7 @@ private[ml] object DefaultParamsWriter { * - uid * - paramMap * - (optionally, extra metadata) + * * @param extraMetadata Extra metadata to be saved at same level as uid, paramMap, etc. * @param paramMap If given, this is saved in the "paramMap" field. * Otherwise, all [[org.apache.spark.ml.param.Param]]s are encoded using @@ -222,6 +226,22 @@ private[ml] object DefaultParamsWriter { sc: SparkContext, extraMetadata: Option[JObject] = None, paramMap: Option[JValue] = None): Unit = { + val metadataPath = new Path(path, "metadata").toString + val metadataJson = getMetadataToSave(instance, sc, extraMetadata, paramMap) + sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + } + + /** + * Helper for [[saveMetadata()]] which extracts the JSON to save. + * This is useful for ensemble models which need to save metadata for many sub-models. + * + * @see [[saveMetadata()]] for details on what this includes. 
+ */ + def getMetadataToSave( + instance: Params, + sc: SparkContext, + extraMetadata: Option[JObject] = None, + paramMap: Option[JValue] = None): String = { val uid = instance.uid val cls = instance.getClass.getName val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]] @@ -239,9 +259,8 @@ private[ml] object DefaultParamsWriter { case None => basicMetadata } - val metadataPath = new Path(path, "metadata").toString - val metadataJson = compact(render(metadata)) - sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath) + val metadataJson: String = compact(render(metadata)) + metadataJson } } @@ -249,6 +268,7 @@ private[ml] object DefaultParamsWriter { * Default [[MLReader]] implementation for transformers and estimators that contain basic * (json4s-serializable) params and no data. This will not handle more complex params or types with * data (e.g., models with coefficients). + * * @tparam T ML instance type * TODO: Consider adding check for correct class name. */ @@ -268,6 +288,7 @@ private[ml] object DefaultParamsReader { /** * All info from metadata file. + * * @param params paramMap, as a [[JValue]] * @param metadata All metadata, including the other fields * @param metadataJson Full metadata file String (for debugging) @@ -304,13 +325,26 @@ private[ml] object DefaultParamsReader { } /** - * Load metadata from file. + * Load metadata saved using [[DefaultParamsWriter.saveMetadata()]] + * * @param expectedClassName If non empty, this is checked against the loaded metadata. * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata */ def loadMetadata(path: String, sc: SparkContext, expectedClassName: String = ""): Metadata = { val metadataPath = new Path(path, "metadata").toString val metadataStr = sc.textFile(metadataPath, 1).first() + parseMetadata(metadataStr, expectedClassName) + } + + /** + * Parse metadata JSON string produced by [[DefaultParamsWriter.getMetadataToSave()]]. + * This is a helper function for [[loadMetadata()]]. + * + * @param metadataStr JSON string of metadata + * @param expectedClassName If non empty, this is checked against the loaded metadata. + * @throws IllegalArgumentException if expectedClassName is specified and does not match metadata + */ + def parseMetadata(metadataStr: String, expectedClassName: String = ""): Metadata = { val metadata = parse(metadataStr) implicit val format = DefaultFormats diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 052bc83c38..aaaa429103 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} @@ -34,7 +34,8 @@ import org.apache.spark.sql.{DataFrame, Row} /** * Test suite for [[RandomForestClassifier]]. 
*/ -class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class RandomForestClassifierSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import RandomForestClassifierSuite.compareAPIs @@ -190,27 +191,24 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* - test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = - Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Classification)).toArray - val oldModel = new OldRandomForestModel(OldAlgo.Classification, trees) - val newModel = RandomForestClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = RandomForestClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + test("read/write") { + def checkModelData( + model: RandomForestClassificationModel, + model2: RandomForestClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) + assert(model.numClasses === model2.numClasses) } + + val rf = new RandomForestClassifier().setNumTrees(2) + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "entropy") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) } - */ } private object RandomForestClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index 2ab4f1b146..ca400e1914 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -30,7 +30,8 @@ import org.apache.spark.sql.DataFrame /** * Test suite for [[RandomForestRegressor]]. 
*/ -class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest{ import RandomForestRegressorSuite.compareAPIs @@ -106,26 +107,23 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* - test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val oldModel = new OldRandomForestModel(OldAlgo.Regression, trees) - val newModel = RandomForestRegressionModel.fromOld(oldModel) - - // Save model, load it back, and compare. - try { - newModel.save(sc, path) - val sameNewModel = RandomForestRegressionModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + test("read/write") { + def checkModelData( + model: RandomForestRegressionModel, + model2: RandomForestRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val rf = new RandomForestRegressor().setNumTrees(2) + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("impurity" -> "variance") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) } - */ } private object RandomForestRegressorSuite extends SparkFunSuite { -- cgit v1.2.3 From 855ed44ed31210d2001d7ce67c8fa99f8416edd3 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 4 Apr 2016 10:54:06 -0700 Subject: [SPARK-14176][SQL] Add DataFrameWriter.trigger to set the stream batch period ## What changes were proposed in this pull request? Add a processing time trigger to control the batch processing speed ## How was this patch tested? Unit tests Author: Shixiong Zhu Closes #11976 from zsxwing/trigger. 
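For illustration, a minimal end-to-end sketch of the new API, patterned on the test added in
DataFrameReaderWriterSuite (the stub source format comes from that test; the checkpoint path
is just a placeholder):

  import scala.concurrent.duration._
  import org.apache.spark.sql.ProcessingTime

  // Read a streaming DataFrame from the test-only stub source.
  val df = sqlContext.read
    .format("org.apache.spark.sql.streaming.test")
    .stream("/test")

  // Without .trigger(...) the query runs batches as fast as possible (ProcessingTime(0)).
  // With a processing-time trigger, a new batch starts at most once per interval.
  val query = df.write
    .format("org.apache.spark.sql.streaming.test")
    .option("checkpointLocation", "/tmp/checkpoint")   // placeholder path
    .trigger(ProcessingTime(10.seconds))
    .startStream()

  query.stop()

If a batch takes longer than the interval, a warning is logged and the next batch starts at the
next interval boundary (see notifyBatchFallingBehind and nextBatchTime in TriggerExecutor.scala).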
--- .../apache/spark/sql/ContinuousQueryManager.scala | 11 +- .../org/apache/spark/sql/DataFrameWriter.scala | 34 +++++- .../main/scala/org/apache/spark/sql/Trigger.scala | 133 +++++++++++++++++++++ .../sql/execution/streaming/StreamExecution.scala | 24 ++-- .../sql/execution/streaming/TriggerExecutor.scala | 72 +++++++++++ .../org/apache/spark/sql/ProcessingTimeSuite.scala | 40 +++++++ .../scala/org/apache/spark/sql/StreamTest.scala | 6 +- .../streaming/ProcessingTimeExecutorSuite.scala | 78 ++++++++++++ .../sql/streaming/DataFrameReaderWriterSuite.scala | 28 +++++ 9 files changed, 413 insertions(+), 13 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 465feeb604..2306df09b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -171,13 +171,20 @@ class ContinuousQueryManager(sqlContext: SQLContext) { name: String, checkpointLocation: String, df: DataFrame, - sink: Sink): ContinuousQuery = { + sink: Sink, + trigger: Trigger = ProcessingTime(0)): ContinuousQuery = { activeQueriesLock.synchronized { if (activeQueries.contains(name)) { throw new IllegalArgumentException( s"Cannot start query with name $name as a query with that name is already active") } - val query = new StreamExecution(sqlContext, name, checkpointLocation, df.logicalPlan, sink) + val query = new StreamExecution( + sqlContext, + name, + checkpointLocation, + df.logicalPlan, + sink, + trigger) query.start() activeQueries.put(name, query) query diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index c07bd0e7b7..3332a997cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -77,6 +77,35 @@ final class DataFrameWriter private[sql](df: DataFrame) { this } + /** + * :: Experimental :: + * Set the trigger for the stream query. The default value is `ProcessingTime(0)` and it will run + * the query as fast as possible. + * + * Scala Example: + * {{{ + * def.writer.trigger(ProcessingTime("10 seconds")) + * + * import scala.concurrent.duration._ + * def.writer.trigger(ProcessingTime(10.seconds)) + * }}} + * + * Java Example: + * {{{ + * def.writer.trigger(ProcessingTime.create("10 seconds")) + * + * import java.util.concurrent.TimeUnit + * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + * + * @since 2.0.0 + */ + @Experimental + def trigger(trigger: Trigger): DataFrameWriter = { + this.trigger = trigger + this + } + /** * Specifies the underlying output data source. Built-in options include "parquet", "json", etc. 
* @@ -261,7 +290,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { queryName, checkpointLocation, df, - dataSource.createSink()) + dataSource.createSink(), + trigger) } /** @@ -552,6 +582,8 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var mode: SaveMode = SaveMode.ErrorIfExists + private var trigger: Trigger = ProcessingTime(0L) + private var extraOptions = new scala.collection.mutable.HashMap[String, String] private var partitioningColumns: Option[Seq[String]] = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala new file mode 100644 index 0000000000..c4e54b3f90 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.Duration + +import org.apache.commons.lang3.StringUtils + +import org.apache.spark.annotation.Experimental +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * :: Experimental :: + * Used to indicate how often results should be produced by a [[ContinuousQuery]]. + */ +@Experimental +sealed trait Trigger {} + +/** + * :: Experimental :: + * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0, + * the query will run as fast as possible. + * + * Scala Example: + * {{{ + * def.writer.trigger(ProcessingTime("10 seconds")) + * + * import scala.concurrent.duration._ + * def.writer.trigger(ProcessingTime(10.seconds)) + * }}} + * + * Java Example: + * {{{ + * def.writer.trigger(ProcessingTime.create("10 seconds")) + * + * import java.util.concurrent.TimeUnit + * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + */ +@Experimental +case class ProcessingTime(intervalMs: Long) extends Trigger { + require(intervalMs >= 0, "the interval of trigger should not be negative") +} + +/** + * :: Experimental :: + * Used to create [[ProcessingTime]] triggers for [[ContinuousQuery]]s. + */ +@Experimental +object ProcessingTime { + + /** + * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. 
+ * + * Example: + * {{{ + * def.writer.trigger(ProcessingTime("10 seconds")) + * }}} + */ + def apply(interval: String): ProcessingTime = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "interval cannot be null or blank.") + } + val cal = if (interval.startsWith("interval")) { + CalendarInterval.fromString(interval) + } else { + CalendarInterval.fromString("interval " + interval) + } + if (cal == null) { + throw new IllegalArgumentException(s"Invalid interval: $interval") + } + if (cal.months > 0) { + throw new IllegalArgumentException(s"Doesn't support month or year interval: $interval") + } + new ProcessingTime(cal.microseconds / 1000) + } + + /** + * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * + * Example: + * {{{ + * import scala.concurrent.duration._ + * def.writer.trigger(ProcessingTime(10.seconds)) + * }}} + */ + def apply(interval: Duration): ProcessingTime = { + new ProcessingTime(interval.toMillis) + } + + /** + * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * + * Example: + * {{{ + * def.writer.trigger(ProcessingTime.create("10 seconds")) + * }}} + */ + def create(interval: String): ProcessingTime = { + apply(interval) + } + + /** + * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * + * Example: + * {{{ + * import java.util.concurrent.TimeUnit + * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * }}} + */ + def create(interval: Long, unit: TimeUnit): ProcessingTime = { + new ProcessingTime(unit.toMillis(interval)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 511e30c70c..64f80699ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -47,16 +47,14 @@ class StreamExecution( override val name: String, val checkpointRoot: String, private[sql] val logicalPlan: LogicalPlan, - val sink: Sink) extends ContinuousQuery with Logging { + val sink: Sink, + val trigger: Trigger) extends ContinuousQuery with Logging { /** An monitor used to wait/notify when batches complete. */ private val awaitBatchLock = new Object private val startLatch = new CountDownLatch(1) private val terminationLatch = new CountDownLatch(1) - /** Minimum amount of time in between the start of each batch. */ - private val minBatchTime = 10 - /** * Tracks how much data we have processed and committed to the sink or state store from each * input source. @@ -79,6 +77,10 @@ class StreamExecution( /** A list of unique sources in the query plan. 
*/ private val uniqueSources = sources.distinct + private val triggerExecutor = trigger match { + case t: ProcessingTime => ProcessingTimeExecutor(t) + } + /** Defines the internal state of execution */ @volatile private var state: State = INITIALIZED @@ -154,11 +156,15 @@ class StreamExecution( SQLContext.setActive(sqlContext) populateStartOffsets() logDebug(s"Stream running from $committedOffsets to $availableOffsets") - while (isActive) { - if (dataAvailable) runBatch() - commitAndConstructNextBatch() - Thread.sleep(minBatchTime) // TODO: Could be tighter - } + triggerExecutor.execute(() => { + if (isActive) { + if (dataAvailable) runBatch() + commitAndConstructNextBatch() + true + } else { + false + } + }) } catch { case _: InterruptedException if state == TERMINATED => // interrupted by stop() case NonFatal(e) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala new file mode 100644 index 0000000000..a1132d5106 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TriggerExecutor.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.ProcessingTime +import org.apache.spark.util.{Clock, SystemClock} + +trait TriggerExecutor { + + /** + * Execute batches using `batchRunner`. If `batchRunner` runs `false`, terminate the execution. + */ + def execute(batchRunner: () => Boolean): Unit +} + +/** + * A trigger executor that runs a batch every `intervalMs` milliseconds. + */ +case class ProcessingTimeExecutor(processingTime: ProcessingTime, clock: Clock = new SystemClock()) + extends TriggerExecutor with Logging { + + private val intervalMs = processingTime.intervalMs + + override def execute(batchRunner: () => Boolean): Unit = { + while (true) { + val batchStartTimeMs = clock.getTimeMillis() + val terminated = !batchRunner() + if (intervalMs > 0) { + val batchEndTimeMs = clock.getTimeMillis() + val batchElapsedTimeMs = batchEndTimeMs - batchStartTimeMs + if (batchElapsedTimeMs > intervalMs) { + notifyBatchFallingBehind(batchElapsedTimeMs) + } + if (terminated) { + return + } + clock.waitTillTime(nextBatchTime(batchEndTimeMs)) + } else { + if (terminated) { + return + } + } + } + } + + /** Called when a batch falls behind. Expose for test only */ + def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { + logWarning("Current batch is falling behind. 
The trigger interval is " + + s"${intervalMs} milliseconds, but spent ${realElapsedTimeMs} milliseconds") + } + + /** Return the next multiple of intervalMs */ + def nextBatchTime(now: Long): Long = { + (now - 1) / intervalMs * intervalMs + intervalMs + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala new file mode 100644 index 0000000000..0d18a645f6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ProcessingTimeSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ + +import org.apache.spark.SparkFunSuite + +class ProcessingTimeSuite extends SparkFunSuite { + + test("create") { + assert(ProcessingTime(10.seconds).intervalMs === 10 * 1000) + assert(ProcessingTime.create(10, TimeUnit.SECONDS).intervalMs === 10 * 1000) + assert(ProcessingTime("1 minute").intervalMs === 60 * 1000) + assert(ProcessingTime("interval 1 minute").intervalMs === 60 * 1000) + + intercept[IllegalArgumentException] { ProcessingTime(null: String) } + intercept[IllegalArgumentException] { ProcessingTime("") } + intercept[IllegalArgumentException] { ProcessingTime("invalid") } + intercept[IllegalArgumentException] { ProcessingTime("1 month") } + intercept[IllegalArgumentException] { ProcessingTime("1 year") } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 550c3c6f9c..3444e56e9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -288,7 +288,11 @@ trait StreamTest extends QueryTest with Timeouts { currentStream = sqlContext .streams - .startQuery(StreamExecution.nextName, metadataRoot, stream, sink) + .startQuery( + StreamExecution.nextName, + metadataRoot, + stream, + sink) .asInstanceOf[StreamExecution] currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala new file mode 100644 index 0000000000..dd5f92248b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ProcessingTimeExecutorSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.concurrent.{CountDownLatch, TimeUnit} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.ProcessingTime +import org.apache.spark.util.ManualClock + +class ProcessingTimeExecutorSuite extends SparkFunSuite { + + test("nextBatchTime") { + val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(100)) + assert(processingTimeExecutor.nextBatchTime(1) === 100) + assert(processingTimeExecutor.nextBatchTime(99) === 100) + assert(processingTimeExecutor.nextBatchTime(100) === 100) + assert(processingTimeExecutor.nextBatchTime(101) === 200) + assert(processingTimeExecutor.nextBatchTime(150) === 200) + } + + private def testBatchTermination(intervalMs: Long): Unit = { + var batchCounts = 0 + val processingTimeExecutor = ProcessingTimeExecutor(ProcessingTime(intervalMs)) + processingTimeExecutor.execute(() => { + batchCounts += 1 + // If the batch termination works well, batchCounts should be 3 after `execute` + batchCounts < 3 + }) + assert(batchCounts === 3) + } + + test("batch termination") { + testBatchTermination(0) + testBatchTermination(10) + } + + test("notifyBatchFallingBehind") { + val clock = new ManualClock() + @volatile var batchFallingBehindCalled = false + val latch = new CountDownLatch(1) + val t = new Thread() { + override def run(): Unit = { + val processingTimeExecutor = new ProcessingTimeExecutor(ProcessingTime(100), clock) { + override def notifyBatchFallingBehind(realElapsedTimeMs: Long): Unit = { + batchFallingBehindCalled = true + } + } + processingTimeExecutor.execute(() => { + latch.countDown() + clock.waitTillTime(200) + false + }) + } + } + t.start() + // Wait until the batch is running so that we don't call `advance` too early + assert(latch.await(10, TimeUnit.SECONDS), "the batch has not yet started in 10 seconds") + clock.advance(200) + t.join() + assert(batchFallingBehindCalled === true) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index 102473d7d0..28c558208f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.streaming.test +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration._ + import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ @@ -275,4 +279,28 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(activeStreamNames.contains("name")) sqlContext.streams.active.foreach(_.stop()) } + + test("trigger") { + val df = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream("/test") + + var q = df.write + .format("org.apache.spark.sql.streaming.test") + 
.option("checkpointLocation", newMetadataDir) + .trigger(ProcessingTime(10.seconds)) + .startStream() + q.stop() + + assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(10000)) + + q = df.write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", newMetadataDir) + .trigger(ProcessingTime.create(100, TimeUnit.SECONDS)) + .startStream() + q.stop() + + assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) + } } -- cgit v1.2.3 From 5743c6476dbef50852b7f9873112a2d299966ebd Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Apr 2016 10:56:26 -0700 Subject: [SPARK-12981] [SQL] extract Pyhton UDF in physical plan ## What changes were proposed in this pull request? Currently we extract Python UDFs into a special logical plan EvaluatePython in analyzer, But EvaluatePython is not part of catalyst, many rules have no knowledge of it , which will break many things (for example, filter push down or column pruning). We should treat Python UDFs as normal expressions, until we want to evaluate in physical plan, we could extract them in end of optimizer, or physical plan. This PR extract Python UDFs in physical plan. Closes #10935 ## How was this patch tested? Added regression tests. Author: Davies Liu Closes #12127 from davies/py_udf. --- python/pyspark/sql/tests.py | 9 +++ .../spark/sql/execution/QueryExecution.scala | 1 + .../spark/sql/execution/SparkStrategies.scala | 2 - .../sql/execution/python/EvaluatePython.scala | 23 ------ .../sql/execution/python/ExtractPythonUDFs.scala | 94 ++++++++++++---------- .../spark/sql/execution/python/PythonUDF.scala | 3 +- .../apache/spark/sql/internal/SessionState.scala | 1 - .../apache/spark/sql/hive/HiveSessionState.scala | 1 - 8 files changed, 64 insertions(+), 70 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 536ef55251..e4f79c911c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -343,6 +343,15 @@ class SQLTests(ReusedPySparkTestCase): [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) + def test_udf_with_aggregate_function(self): + df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a == 1, BooleanType()) + sel = df.select(col("key")).distinct().filter(my_filter(col("key"))) + self.assertEqual(sel.collect(), [Row(key=1)]) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.sqlCtx.read.json(rdd) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 63eb1aa24e..f5e1e77263 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -74,6 +74,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { /** A sequence of rules that will be applied in order to the physical plan before execution. 
*/ protected def preparations: Seq[Rule[SparkPlan]] = Seq( + python.ExtractPythonUDFs, PlanSubqueries(sqlContext), EnsureRequirements(sqlContext.conf), CollapseCodegenStages(sqlContext.conf), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e1fabf519a..e52f05a5f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -392,8 +392,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.RepartitionByExpression(expressions, child, nPartitions) => exchange.ShuffleExchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil - case e @ python.EvaluatePython(udfs, child, _) => - python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil case BroadcastHint(child) => planLater(child) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index f3d1c44b25..3b05e29e52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -35,30 +35,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -/** - * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple. - */ -case class EvaluatePython( - udfs: Seq[PythonUDF], - child: LogicalPlan, - resultAttribute: Seq[AttributeReference]) - extends logical.UnaryNode { - - def output: Seq[Attribute] = child.output ++ resultAttribute - - // References should not include the produced attribute. 
- override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references)) -} - - object EvaluatePython { - def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = { - val resultAttrs = udfs.zipWithIndex.map { case (u, i) => - AttributeReference(s"pythonUDF$i", u.dataType)() - } - new EvaluatePython(udfs, child, resultAttrs) - } - def takeAndServe(df: DataFrame, n: Int): Int = { registerPicklers() df.withNewExecutionId { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 0934cd135d..d72b3d347d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.python import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.SparkPlan /** * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { +private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] { private def hasPythonUDF(e: Expression): Boolean = { e.find(_.isInstanceOf[PythonUDF]).isDefined @@ -54,49 +54,61 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { case e => e.children.flatMap(collectEvaluatableUDF) } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - // Skip EvaluatePython nodes. - case plan: EvaluatePython => plan + def apply(plan: SparkPlan): SparkPlan = plan transformUp { + case plan: SparkPlan => extract(plan) + } - case plan: LogicalPlan if plan.resolved => - // Extract any PythonUDFs from the current operator. - val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved) - if (udfs.isEmpty) { - // If there aren't any, we are done. - plan - } else { - val attributeMap = mutable.HashMap[PythonUDF, Expression]() - // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => - // Pick the UDF we are going to evaluate - val validUdfs = udfs.filter { case udf => - // Check to make sure that the UDF can be evaluated with only the input of this child. - udf.references.subsetOf(child.outputSet) - } - if (validUdfs.nonEmpty) { - val evaluation = EvaluatePython(validUdfs, child) - attributeMap ++= validUdfs.zip(evaluation.resultAttribute) - evaluation - } else { - child - } + /** + * Extract all the PythonUDFs from the current operator. + */ + def extract(plan: SparkPlan): SparkPlan = { + val udfs = plan.expressions.flatMap(collectEvaluatableUDF) + if (udfs.isEmpty) { + // If there aren't any, we are done. 
+ plan + } else { + val attributeMap = mutable.HashMap[PythonUDF, Expression]() + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Pick the UDF we are going to evaluate + val validUdfs = udfs.filter { case udf => + // Check to make sure that the UDF can be evaluated with only the input of this child. + udf.references.subsetOf(child.outputSet) } - // Other cases are disallowed as they are ambiguous or would require a cartesian - // product. - udfs.filterNot(attributeMap.contains).foreach { udf => - if (udf.references.subsetOf(plan.inputSet)) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.") + if (validUdfs.nonEmpty) { + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => + AttributeReference(s"pythonUDF$i", u.dataType)() } + val evaluation = BatchPythonEvaluation(validUdfs, child.output ++ resultAttrs, child) + attributeMap ++= validUdfs.zip(resultAttrs) + evaluation + } else { + child + } + } + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + udfs.filterNot(attributeMap.contains).foreach { udf => + if (udf.references.subsetOf(plan.inputSet)) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.") } + } + + val rewritten = plan.transformExpressions { + case p: PythonUDF if attributeMap.contains(p) => + attributeMap(p) + }.withNewChildren(newChildren) + // extract remaining python UDFs recursively + val newPlan = extract(rewritten) + if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. 
- logical.Project( - plan.output, - plan.transformExpressions { - case p: PythonUDF if attributeMap.contains(p) => attributeMap(p) - }.withNewChildren(newChildren)) + execution.Project(plan.output, newPlan) + } else { + newPlan } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 4f1b837158..59d7e8dd6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.python import org.apache.spark.api.python.PythonFunction -import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable} import org.apache.spark.sql.types.DataType @@ -30,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression]) - extends Expression with Unevaluable with NonSQLExpression with Logging { + extends Expression with Unevaluable with NonSQLExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index cd3d254d1e..cd29def3be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -64,7 +64,6 @@ private[sql] class SessionState(ctx: SQLContext) { lazy val analyzer: Analyzer = { new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - python.ExtractPythonUDFs :: PreInsertCastAndRename :: DataSourceAnalysis :: (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index ff40c366c8..829afa8432 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -60,7 +60,6 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) catalog.OrcConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - python.ExtractPythonUDFs :: PreInsertCastAndRename :: DataSourceAnalysis :: (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil) -- cgit v1.2.3 From 27dad6f658f04815e1f3b93c68974bfd31500bed Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 4 Apr 2016 13:19:34 -0700 Subject: [SPARK-14364][SPARK] HeartbeatReceiver object should be private ## What changes were proposed in this pull request? It's a mistake that HeartbeatReceiver object was made public in Spark 1.x. ## How was this patch tested? N/A Author: Reynold Xin Closes #12148 from rxin/SPARK-14364. 
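To spell out the effect of the change (sketch only; MyApp below is a hypothetical user object,
not part of this patch): `private[spark]` restricts access to the org.apache.spark package tree,
so 1.x user code that happened to reference the object no longer compiles:

  // Hypothetical application code outside org.apache.spark:
  object MyApp {
    // Compiled against Spark 1.x, but fails after this change because the
    // companion object is now private[spark]:
    // val endpointName = org.apache.spark.HeartbeatReceiver.ENDPOINT_NAME
  }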
--- core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index e8748dd80a..61f689ec8c 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -220,6 +220,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) } } -object HeartbeatReceiver { + +private[spark] object HeartbeatReceiver { val ENDPOINT_NAME = "HeartbeatReceiver" } -- cgit v1.2.3 From 7143904700435265975d36f073cce2833467e121 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 4 Apr 2016 13:26:18 -0700 Subject: [SPARK-14358] Change SparkListener from a trait to an abstract class ## What changes were proposed in this pull request? Scala traits are difficult to maintain binary compatibility on, and as a result we had to introduce JavaSparkListener. In Spark 2.0 we can change SparkListener from a trait to an abstract class and then remove JavaSparkListener. ## How was this patch tested? Updated related unit tests. Author: Reynold Xin Closes #12142 from rxin/SPARK-14358. --- .../java/org/apache/spark/JavaSparkListener.java | 88 -------- .../org/apache/spark/SparkFirehoseListener.java | 2 +- .../scala/org/apache/spark/HeartbeatReceiver.scala | 2 +- .../org/apache/spark/scheduler/SparkListener.scala | 251 ++++++--------------- .../spark/scheduler/StatsReportListener.scala | 199 ++++++++++++++++ project/MimaExcludes.scala | 11 +- .../ui/StreamingJobProgressListener.scala | 2 +- 7 files changed, 276 insertions(+), 279 deletions(-) delete mode 100644 core/src/main/java/org/apache/spark/JavaSparkListener.java create mode 100644 core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java deleted file mode 100644 index 23bc9a2e81..0000000000 --- a/core/src/main/java/org/apache/spark/JavaSparkListener.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark; - -import org.apache.spark.scheduler.*; - -/** - * Java clients should extend this class instead of implementing - * SparkListener directly. This is to prevent java clients - * from breaking when new events are added to the SparkListener - * trait. - * - * This is a concrete class instead of abstract to enforce - * new events get added to both the SparkListener and this adapter - * in lockstep. 
- */ -public class JavaSparkListener implements SparkListener { - - @Override - public void onStageCompleted(SparkListenerStageCompleted stageCompleted) { } - - @Override - public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { } - - @Override - public void onTaskStart(SparkListenerTaskStart taskStart) { } - - @Override - public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { } - - @Override - public void onTaskEnd(SparkListenerTaskEnd taskEnd) { } - - @Override - public void onJobStart(SparkListenerJobStart jobStart) { } - - @Override - public void onJobEnd(SparkListenerJobEnd jobEnd) { } - - @Override - public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { } - - @Override - public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { } - - @Override - public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { } - - @Override - public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { } - - @Override - public void onApplicationStart(SparkListenerApplicationStart applicationStart) { } - - @Override - public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { } - - @Override - public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { } - - @Override - public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { } - - @Override - public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { } - - @Override - public void onBlockUpdated(SparkListenerBlockUpdated blockUpdated) { } - - @Override - public void onOtherEvent(SparkListenerEvent event) { } - -} diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java index e6b24afd88..97eed611e8 100644 --- a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java +++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java @@ -28,7 +28,7 @@ import org.apache.spark.scheduler.*; * this was a concrete Scala class, default implementations of new event handlers would be inherited * from the SparkListener trait). */ -public class SparkFirehoseListener implements SparkListener { +public class SparkFirehoseListener implements SparkListenerInterface { public void onEvent(SparkListenerEvent event) { } diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 61f689ec8c..2bdbd3fae9 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -56,7 +56,7 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) * Lives in the driver to receive heartbeats from executors.. 
*/ private[spark] class HeartbeatReceiver(sc: SparkContext, clock: Clock) - extends ThreadSafeRpcEndpoint with SparkListener with Logging { + extends SparkListener with ThreadSafeRpcEndpoint with Logging { def this(sc: SparkContext) { this(sc, new SystemClock) diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 586173f180..080ea6c33a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -151,275 +151,152 @@ private[spark] trait SparkHistoryListenerFactory { def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] } + /** - * :: DeveloperApi :: - * Interface for listening to events from the Spark scheduler. Note that this is an internal - * interface which might change in different Spark releases. Java clients should extend - * {@link JavaSparkListener} + * Interface for listening to events from the Spark scheduler. Most applications should probably + * extend SparkListener or SparkFirehoseListener directly, rather than implementing this class. + * + * Note that this is an internal interface which might change in different Spark releases. */ -@DeveloperApi -trait SparkListener { +private[spark] trait SparkListenerInterface { + /** * Called when a stage completes successfully or fails, with information on the completed stage. */ - def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { } + def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit /** * Called when a stage is submitted */ - def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { } + def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit /** * Called when a task starts */ - def onTaskStart(taskStart: SparkListenerTaskStart) { } + def onTaskStart(taskStart: SparkListenerTaskStart): Unit /** * Called when a task begins remotely fetching its result (will not be called for tasks that do * not need to fetch the result remotely). 
*/ - def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit /** * Called when a task ends */ - def onTaskEnd(taskEnd: SparkListenerTaskEnd) { } + def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit /** * Called when a job starts */ - def onJobStart(jobStart: SparkListenerJobStart) { } + def onJobStart(jobStart: SparkListenerJobStart): Unit /** * Called when a job ends */ - def onJobEnd(jobEnd: SparkListenerJobEnd) { } + def onJobEnd(jobEnd: SparkListenerJobEnd): Unit /** * Called when environment properties have been updated */ - def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { } + def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate): Unit /** * Called when a new block manager has joined */ - def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded) { } + def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit /** * Called when an existing block manager has been removed */ - def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved) { } + def onBlockManagerRemoved(blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit /** * Called when an RDD is manually unpersisted by the application */ - def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) { } + def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit /** * Called when the application starts */ - def onApplicationStart(applicationStart: SparkListenerApplicationStart) { } + def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit /** * Called when the application ends */ - def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { } + def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit /** * Called when the driver receives task metrics from an executor in a heartbeat. */ - def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { } + def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit /** * Called when the driver registers a new executor. */ - def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { } + def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit /** * Called when the driver removes an executor. */ - def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { } + def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit /** * Called when the driver receives a block update info. */ - def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated) { } + def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit /** * Called when other events like SQL-specific events are posted. */ - def onOtherEvent(event: SparkListenerEvent) { } + def onOtherEvent(event: SparkListenerEvent): Unit } + /** * :: DeveloperApi :: - * Simple SparkListener that logs a few summary statistics when each stage completes + * A default implementation for [[SparkListenerInterface]] that has no-op implementations for + * all callbacks. + * + * Note that this is an internal interface which might change in different Spark releases. 
*/ @DeveloperApi -class StatsReportListener extends SparkListener with Logging { - - import org.apache.spark.scheduler.StatsReportListener._ - - private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]() - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val info = taskEnd.taskInfo - val metrics = taskEnd.taskMetrics - if (info != null && metrics != null) { - taskInfoMetrics += ((info, metrics)) - } - } - - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { - implicit val sc = stageCompleted - this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") - showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) - - // Shuffle write - showBytesDistribution("shuffle bytes written:", - (_, metric) => metric.shuffleWriteMetrics.map(_.bytesWritten), taskInfoMetrics) - - // Fetch & I/O - showMillisDistribution("fetch wait time:", - (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics) - showBytesDistribution("remote bytes read:", - (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics) - showBytesDistribution("task result size:", - (_, metric) => Some(metric.resultSize), taskInfoMetrics) - - // Runtime breakdown - val runtimePcts = taskInfoMetrics.map { case (info, metrics) => - RuntimePercentage(info.duration, metrics) - } - showDistribution("executor (non-fetch) time pct: ", - Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%") - showDistribution("fetch wait time pct: ", - Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%") - showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%") - taskInfoMetrics.clear() - } - - private def getStatusDetail(info: StageInfo): String = { - val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("") - val timeTaken = info.submissionTime.map( - x => info.completionTime.getOrElse(System.currentTimeMillis()) - x - ).getOrElse("-") - - s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + - s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + - s"Took: $timeTaken msec" - } +abstract class SparkListener extends SparkListenerInterface { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { } -} + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { } -private[spark] object StatsReportListener extends Logging { - - // For profiling, the extremes are more interesting - val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) - val probabilities = percentiles.map(_ / 100.0) - val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" - - def extractDoubleDistribution( - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], - getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = { - Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) }) - } - - // Is there some way to setup the types that I can get rid of this completely? 
- def extractLongDistribution( - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], - getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = { - extractDoubleDistribution( - taskInfoMetrics, - (info, metric) => { getMetric(info, metric).map(_.toDouble) }) - } - - def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { - val stats = d.statCounter - val quantiles = d.getQuantiles(probabilities).map(formatNumber) - logInfo(heading + stats) - logInfo(percentilesHeader) - logInfo("\t" + quantiles.mkString("\t")) - } - - def showDistribution( - heading: String, - dOpt: Option[Distribution], - formatNumber: Double => String) { - dOpt.foreach { d => showDistribution(heading, d, formatNumber)} - } - - def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { - def f(d: Double): String = format.format(d) - showDistribution(heading, dOpt, f _) - } - - def showDistribution( - heading: String, - format: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Double], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { - showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format) - } - - def showBytesDistribution( - heading: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Long], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { - showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) - } - - def showBytesDistribution(heading: String, dOpt: Option[Distribution]) { - dOpt.foreach { dist => showBytesDistribution(heading, dist) } - } - - def showBytesDistribution(heading: String, dist: Distribution) { - showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String) - } - - def showMillisDistribution(heading: String, dOpt: Option[Distribution]) { - showDistribution(heading, dOpt, - (d => StatsReportListener.millisToString(d.toLong)): Double => String) - } - - def showMillisDistribution( - heading: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Long], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { - showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) - } - - val seconds = 1000L - val minutes = seconds * 60 - val hours = minutes * 60 + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { } - /** - * Reformat a time interval in milliseconds to a prettier format for output - */ - def millisToString(ms: Long): String = { - val (size, units) = - if (ms > hours) { - (ms.toDouble / hours, "hours") - } else if (ms > minutes) { - (ms.toDouble / minutes, "min") - } else if (ms > seconds) { - (ms.toDouble / seconds, "s") - } else { - (ms.toDouble, "ms") - } - "%.1f %s".format(size, units) - } -} + override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit = { } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { } + + override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate): Unit = { } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { } + + override def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD): Unit = { } + + override def onApplicationStart(applicationStart: SparkListenerApplicationStart): Unit = { } + + override def 
onApplicationEnd(applicationEnd: SparkListenerApplicationEnd): Unit = { } + + override def onExecutorMetricsUpdate( + executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = { } + + override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = { } + + override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = { } + + override def onBlockUpdated(blockUpdated: SparkListenerBlockUpdated): Unit = { } -private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) - -private object RuntimePercentage { - def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { - val denom = totalTime.toDouble - val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime) - val fetch = fetchTime.map(_ / denom) - val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom - val other = 1.0 - (exec + fetch.getOrElse(0d)) - RuntimePercentage(exec, fetch, other) - } + override def onOtherEvent(event: SparkListenerEvent): Unit = { } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala new file mode 100644 index 0000000000..309f4b806b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.collection.mutable + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Distribution, Utils} + + +/** + * :: DeveloperApi :: + * Simple SparkListener that logs a few summary statistics when each stage completes. 
+ */ +@DeveloperApi +class StatsReportListener extends SparkListener with Logging { + + import org.apache.spark.scheduler.StatsReportListener._ + + private val taskInfoMetrics = mutable.Buffer[(TaskInfo, TaskMetrics)]() + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { + val info = taskEnd.taskInfo + val metrics = taskEnd.taskMetrics + if (info != null && metrics != null) { + taskInfoMetrics += ((info, metrics)) + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { + implicit val sc = stageCompleted + this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") + showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) + + // Shuffle write + showBytesDistribution("shuffle bytes written:", + (_, metric) => metric.shuffleWriteMetrics.map(_.bytesWritten), taskInfoMetrics) + + // Fetch & I/O + showMillisDistribution("fetch wait time:", + (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics) + showBytesDistribution("remote bytes read:", + (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics) + showBytesDistribution("task result size:", + (_, metric) => Some(metric.resultSize), taskInfoMetrics) + + // Runtime breakdown + val runtimePcts = taskInfoMetrics.map { case (info, metrics) => + RuntimePercentage(info.duration, metrics) + } + showDistribution("executor (non-fetch) time pct: ", + Distribution(runtimePcts.map(_.executorPct * 100)), "%2.0f %%") + showDistribution("fetch wait time pct: ", + Distribution(runtimePcts.flatMap(_.fetchPct.map(_ * 100))), "%2.0f %%") + showDistribution("other time pct: ", Distribution(runtimePcts.map(_.other * 100)), "%2.0f %%") + taskInfoMetrics.clear() + } + + private def getStatusDetail(info: StageInfo): String = { + val failureReason = info.failureReason.map("(" + _ + ")").getOrElse("") + val timeTaken = info.submissionTime.map( + x => info.completionTime.getOrElse(System.currentTimeMillis()) - x + ).getOrElse("-") + + s"Stage(${info.stageId}, ${info.attemptId}); Name: '${info.name}'; " + + s"Status: ${info.getStatusString}$failureReason; numTasks: ${info.numTasks}; " + + s"Took: $timeTaken msec" + } + +} + +private[spark] object StatsReportListener extends Logging { + + // For profiling, the extremes are more interesting + val percentiles = Array[Int](0, 5, 10, 25, 50, 75, 90, 95, 100) + val probabilities = percentiles.map(_ / 100.0) + val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%" + + def extractDoubleDistribution( + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], + getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = { + Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) }) + } + + // Is there some way to setup the types that I can get rid of this completely? 
+ def extractLongDistribution( + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], + getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = { + extractDoubleDistribution( + taskInfoMetrics, + (info, metric) => { getMetric(info, metric).map(_.toDouble) }) + } + + def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { + val stats = d.statCounter + val quantiles = d.getQuantiles(probabilities).map(formatNumber) + logInfo(heading + stats) + logInfo(percentilesHeader) + logInfo("\t" + quantiles.mkString("\t")) + } + + def showDistribution( + heading: String, + dOpt: Option[Distribution], + formatNumber: Double => String) { + dOpt.foreach { d => showDistribution(heading, d, formatNumber)} + } + + def showDistribution(heading: String, dOpt: Option[Distribution], format: String) { + def f(d: Double): String = format.format(d) + showDistribution(heading, dOpt, f _) + } + + def showDistribution( + heading: String, + format: String, + getMetric: (TaskInfo, TaskMetrics) => Option[Double], + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format) + } + + def showBytesDistribution( + heading: String, + getMetric: (TaskInfo, TaskMetrics) => Option[Long], + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) + } + + def showBytesDistribution(heading: String, dOpt: Option[Distribution]) { + dOpt.foreach { dist => showBytesDistribution(heading, dist) } + } + + def showBytesDistribution(heading: String, dist: Distribution) { + showDistribution(heading, dist, (d => Utils.bytesToString(d.toLong)): Double => String) + } + + def showMillisDistribution(heading: String, dOpt: Option[Distribution]) { + showDistribution(heading, dOpt, + (d => StatsReportListener.millisToString(d.toLong)): Double => String) + } + + def showMillisDistribution( + heading: String, + getMetric: (TaskInfo, TaskMetrics) => Option[Long], + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) + } + + val seconds = 1000L + val minutes = seconds * 60 + val hours = minutes * 60 + + /** + * Reformat a time interval in milliseconds to a prettier format for output + */ + def millisToString(ms: Long): String = { + val (size, units) = + if (ms > hours) { + (ms.toDouble / hours, "hours") + } else if (ms > minutes) { + (ms.toDouble / minutes, "min") + } else if (ms > seconds) { + (ms.toDouble / seconds, "s") + } else { + (ms.toDouble, "ms") + } + "%.1f %s".format(size, units) + } +} + +private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) + +private object RuntimePercentage { + def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { + val denom = totalTime.toDouble + val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime) + val fetch = fetchTime.map(_ / denom) + val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom + val other = 1.0 - (exec + fetch.getOrElse(0d)) + RuntimePercentage(exec, fetch, other) + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 2be490b942..9f245afd50 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -66,7 +66,16 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelation$FileStatusCache") ) ++ Seq( 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory") + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory"), + // SPARK-14358 SparkListener from trait to abstract class + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.JavaSparkListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkFirehoseListener"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.scheduler.SparkListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.jobs.JobProgressListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.exec.ExecutorsListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.env.EnvironmentListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.storage.StorageListener"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.StorageStatusListener") ) ++ Seq( // SPARK-3369 Fix Iterable/Iterator in Java API
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index 6985c37f71..c086df47d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -28,7 +28,7 @@ import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.scheduler._ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) - extends StreamingListener with SparkListener { + extends SparkListener with StreamingListener { private val waitingBatchUIData = new HashMap[Time, BatchUIData] private val runningBatchUIData = new HashMap[Time, BatchUIData]
-- cgit v1.2.3
From cc70f174169f45c85d459126a68bbe43c0bec328 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Apr 2016 13:31:44 -0700
Subject: [SPARK-14334] [SQL] add toLocalIterator for Dataset/DataFrame
## What changes were proposed in this pull request?
RDD.toLocalIterator() can be used to fetch one partition at a time and so reduce memory usage. Right now, for a Dataset/DataFrame we have to use df.rdd.toLocalIterator, which is very slow and also requires a lot of memory (because of the Java serializer, or even the Kryo serializer). This PR introduces an optimized toLocalIterator for Dataset/DataFrame, which is much faster and requires much less memory. For a partition with 5 million rows, `df.rdd.toLocalIterator` took about 100 seconds, but `df.toLocalIterator` took less than 7 seconds. For 10 million rows, rdd.toLocalIterator crashed (not enough memory) with a 4G heap, but df.toLocalIterator finished in 12 seconds. The JDBC server has been updated to use DataFrame.toLocalIterator.
## How was this patch tested?
Existing tests.
Author: Davies Liu
Closes #12114 from davies/local_iterator.
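For reference, a minimal usage sketch of the new API in Scala. It is illustrative only and not part of this patch: the `sqlContext` parameter, the `ToLocalIteratorExample` wrapper, and the sample data are assumptions, while `toDS()` and the `toLocalIterator()` added by this change are the real APIs.

```scala
import org.apache.spark.sql.SQLContext

object ToLocalIteratorExample {
  // Iterate a Dataset on the driver one partition at a time,
  // instead of materializing everything with collect().
  def printAll(sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._                   // enables .toDS() on local Seqs
    val ds = Seq("hello", "world", "spark").toDS()  // assumed sample data
    val it = ds.toLocalIterator()                   // java.util.Iterator[String]
    while (it.hasNext) {
      println(it.next())                            // only one partition is held on the driver at a time
    }
  }
}
```

The PySpark equivalent is `df.toLocalIterator()`, as shown in the `dataframe.py` docstring in the diff below.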
--- .../org/apache/spark/api/python/PythonRDD.scala | 4 +++ python/pyspark/rdd.py | 8 ++--- python/pyspark/sql/dataframe.py | 14 +++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 25 ++++++++++++++++ .../org/apache/spark/sql/execution/SparkPlan.scala | 35 +++++++++++++++------- .../org/apache/spark/sql/JavaDatasetSuite.java | 10 +++++++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 1 + .../SparkExecuteStatementOperation.scala | 2 +- 8 files changed, 83 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 6faa03c12b..4bca16a234 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -453,6 +453,10 @@ private[spark] object PythonRDD extends Logging { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } + def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = { + serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") + } + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 37574cea0b..cd1f64e8aa 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2299,14 +2299,14 @@ class RDD(object): """ Return an iterator that contains all of the elements in this RDD. The iterator will consume as much memory as the largest partition in this RDD. + >>> rdd = sc.parallelize(range(10)) >>> [x for x in rdd.toLocalIterator()] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ - for partition in range(self.getNumPartitions()): - rows = self.context.runJob(self, lambda x: x, [partition]) - for row in rows: - yield row + with SCCallSiteSync(self.context) as css: + port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + return _load_from_socket(port, self._jrdd_deserializer) def _prepare_for_python_RDD(sc, command): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 7a69c4c70c..d473d6b534 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -240,6 +240,20 @@ class DataFrame(object): port = self._jdf.collectToPython() return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + @ignore_unicode_prefix + @since(2.0) + def toLocalIterator(self): + """ + Returns an iterator that contains all of the rows in this :class:`DataFrame`. + The iterator will consume as much memory as the largest partition in this DataFrame. 
+ + >>> list(df.toLocalIterator()) + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + """ + with SCCallSiteSync(self._sc) as css: + port = self._jdf.toPythonIterator() + return _load_from_socket(port, BatchedSerializer(PickleSerializer())) + @ignore_unicode_prefix @since(1.3) def limit(self, num): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a39a2113e5..8dfe8ff702 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.CharArrayWriter import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -2056,6 +2057,24 @@ class Dataset[T] private[sql]( } } + /** + * Return an iterator that contains all of [[Row]]s in this [[Dataset]]. + * + * The iterator will consume as much memory as the largest partition in this [[Dataset]]. + * + * Note: this results in multiple Spark jobs, and if the input Dataset is the result + * of a wide transformation (e.g. join with different partitioners), to avoid + * recomputing the input Dataset should be cached first. + * + * @group action + * @since 2.0.0 + */ + def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ => + withNewExecutionId { + queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava + } + } + /** * Returns the number of rows in the [[Dataset]]. * @group action @@ -2300,6 +2319,12 @@ class Dataset[T] private[sql]( } } + protected[sql] def toPythonIterator(): Int = { + withNewExecutionId { + PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) + } + } + //////////////////////////////////////////////////////////////////////////// // Private Helpers //////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ff19d1be1c..4091f65aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -249,20 +249,24 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** * Decode the byte arrays back to UnsafeRows and put them into buffer. 
*/ - private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = { + private def decodeUnsafeRows(bytes: Array[Byte]): Iterator[InternalRow] = { val nFields = schema.length val codec = CompressionCodec.createCodec(SparkEnv.get.conf) val bis = new ByteArrayInputStream(bytes) val ins = new DataInputStream(codec.compressedInputStream(bis)) - var sizeOfNextRow = ins.readInt() - while (sizeOfNextRow >= 0) { - val bs = new Array[Byte](sizeOfNextRow) - ins.readFully(bs) - val row = new UnsafeRow(nFields) - row.pointTo(bs, sizeOfNextRow) - buffer += row - sizeOfNextRow = ins.readInt() + + new Iterator[InternalRow] { + private var sizeOfNextRow = ins.readInt() + override def hasNext: Boolean = sizeOfNextRow >= 0 + override def next(): InternalRow = { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(nFields) + row.pointTo(bs, sizeOfNextRow) + sizeOfNextRow = ins.readInt() + row + } } } @@ -274,11 +278,20 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val results = ArrayBuffer[InternalRow]() byteArrayRdd.collect().foreach { bytes => - decodeUnsafeRows(bytes, results) + decodeUnsafeRows(bytes).foreach(results.+=) } results.toArray } + /** + * Runs this query returning the result as an iterator of InternalRow. + * + * Note: this will trigger multiple jobs (one for each partition). + */ + def executeToIterator(): Iterator[InternalRow] = { + getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows) + } + /** * Runs this query returning the result as an array, using external Row format. */ @@ -325,7 +338,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p) res.foreach { r => - decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf) + decodeUnsafeRows(r.asInstanceOf[Array[Byte]]).foreach(buf.+=) } partsScanned += p.size diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 873f681bdf..f26c57b301 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -85,6 +85,16 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(Arrays.asList("hello"), collected); } + @Test + public void testToLocalIterator() { + List data = Arrays.asList("hello", "world"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + Iterator iter = ds.toLocalIterator(); + Assert.assertEquals("hello", iter.next()); + Assert.assertEquals("world", iter.next()); + Assert.assertFalse(iter.hasNext()); + } + @Test public void testCommonOperation() { List data = Arrays.asList("hello", "world"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 553bc436a6..2aa90568c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -71,6 +71,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.first() == item) assert(ds.take(1).head == item) assert(ds.takeAsList(1).get(0) == item) + assert(ds.toLocalIterator().next() === item) } test("coalesce, repartition") { diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala 
b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index a955314ba3..673a293ce2 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -222,7 +222,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.rdd.toLocalIterator + result.toLocalIterator.asScala } else { result.collect().iterator }
-- cgit v1.2.3
From 400b2f863ffaa01a34a8dae1541c61526fef908b Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 4 Apr 2016 14:41:03 -0700
Subject: [SPARK-14259] [SQL] Merging small files together based on the cost of opening
## What changes were proposed in this pull request?
This PR basically re-does the work in #12068, but with a different model that should work better for small files of different sizes. Under the new model, each file added to a partition is charged its length plus an estimated cost of opening it (`spark.sql.files.openCostInBytes`, 4 MB by default), and the current partition is closed whenever adding the next file would push the accumulated size past `spark.sql.files.maxPartitionBytes`.
## How was this patch tested?
Updated existing tests.
Ran a query on thousands of partitioned small files locally, with all default settings (so the cost to open a file should be overestimated); the task durations become smaller and smaller, which is good (the last few tasks scheduled are the shortest).
Author: Davies Liu
Closes #12095 from davies/file_cost.
--- .../sql/execution/datasources/FileSourceStrategy.scala | 13 +++++-------- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 13 ++++++++----- .../execution/datasources/FileSourceStrategySuite.scala | 14 ++++++++------ 3 files changed, 21 insertions(+), 19 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index a143ac6aec..618d5a522b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -131,9 +131,9 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { case _ => val maxSplitBytes = files.sqlContext.conf.filesMaxPartitionBytes - val maxFileNumInPartition = files.sqlContext.conf.filesMaxNumInPartition + val openCostInBytes = files.sqlContext.conf.filesOpenCostInBytes logInfo(s"Planning scan with bin packing, max size: $maxSplitBytes bytes, " + - s"max #files: $maxFileNumInPartition") + s"open cost is considered as scanning $openCostInBytes bytes.") val splitFiles = selectedPartitions.flatMap { partition => partition.files.flatMap { file => @@ -151,7 +151,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { /** Add the given file to the current partition. */ def addFile(file: PartitionedFile): Unit = { - currentSize += file.length + currentSize += file.length + openCostInBytes currentFiles.append(file) } @@ -171,13 +171,10 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { // Assign files to partitions using "First Fit Decreasing" (FFD) // TODO: consider adding a slop factor here?
splitFiles.foreach { file => - if (currentSize + file.length > maxSplitBytes || - currentFiles.length >= maxFileNumInPartition) { + if (currentSize + file.length > maxSplitBytes) { closePartition() - addFile(file) - } else { - addFile(file) } + addFile(file) } closePartition() partitions diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 6cc72fba48..a7c0be63fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -510,10 +510,13 @@ object SQLConf { doc = "The maximum number of bytes to pack into a single partition when reading files.", isPublic = true) - val FILES_MAX_NUM_IN_PARTITION = longConf("spark.sql.files.maxNumInPartition", - defaultValue = Some(32), - doc = "The maximum number of files to pack into a single partition when reading files.", - isPublic = true) + val FILES_OPEN_COST_IN_BYTES = longConf("spark.sql.files.openCostInBytes", + defaultValue = Some(4 * 1024 * 1024), + doc = "The estimated cost to open a file, measured by the number of bytes could be scanned in" + + " the same time. This is used when putting multiple files into a partition. It's better to" + + " over estimated, then the partitions with small files will be faster than partitions with" + + " bigger files (which is scheduled first).", + isPublic = false) val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", defaultValue = Some(true), @@ -572,7 +575,7 @@ class SQLConf extends Serializable with CatalystConf with Logging { def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) - def filesMaxNumInPartition: Long = getConf(FILES_MAX_NUM_IN_PARTITION) + def filesOpenCostInBytes: Long = getConf(FILES_OPEN_COST_IN_BYTES) def useCompression: Boolean = getConf(COMPRESS_CACHED) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 717a3a80b7..4446a6881c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -76,7 +76,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi "file2" -> 5, "file3" -> 5)) - withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") { + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "11", + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => // 5 byte files should be laid out [(5, 5), (5)] assert(partitions.size == 2, "when checking partitions") @@ -98,11 +99,12 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi createTable( files = Seq( "file1" -> 15, - "file2" -> 4)) + "file2" -> 3)) - withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10") { + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "10", + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => - // Files should be laid out [(0-5), (5-10, 4)] + // Files should be laid out [(0-10), (10-15, 4)] assert(partitions.size == 2, "when checking partitions") assert(partitions(0).files.size == 1, "when checking partition 1") assert(partitions(1).files.size == 2, "when checking partition 2") @@ -132,8 +134,8 @@ class FileSourceStrategySuite 
extends QueryTest with SharedSQLContext with Predi "file5" -> 1, "file6" -> 1)) - withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "3", - SQLConf.FILES_MAX_NUM_IN_PARTITION.key -> "2") { + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "4", + SQLConf.FILES_OPEN_COST_IN_BYTES.key -> "1") { checkScan(table.select('c1)) { partitions => // Files should be laid out [(file1), (file2, file3), (file4, file5), (file6)] assert(partitions.size == 4, "when checking partitions") -- cgit v1.2.3 From 24d7d2e453ab5eef6099a32fb9e8ed60f6ada93a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 4 Apr 2016 16:52:21 -0700 Subject: [SPARK-13579][BUILD] Stop building the main Spark assembly. This change modifies the "assembly/" module to just copy needed dependencies to its build directory, and modifies the packaging script to pick those up (and remove duplicate jars packages in the examples module). I also made some minor adjustments to dependencies to remove some test jars from the final packaging, and remove jars that conflict with each other when packaged separately (e.g. servlet api). Also note that this change restores guava in applications' classpaths, even though it's still shaded inside Spark. This is now needed for the Hadoop libraries that are packaged with Spark, which now are not processed by the shade plugin. Author: Marcelo Vanzin Closes #11796 from vanzin/SPARK-13579. --- assembly/pom.xml | 101 +++++++-------------- bin/spark-class | 11 +-- bin/spark-class2.cmd | 5 +- .../main/scala/org/apache/spark/util/Utils.scala | 4 +- .../org/apache/spark/util/FileAppenderSuite.scala | 72 ++++++++------- dev/deps/spark-deps-hadoop-2.2 | 4 +- dev/deps/spark-deps-hadoop-2.3 | 4 +- dev/deps/spark-deps-hadoop-2.4 | 4 +- dev/deps/spark-deps-hadoop-2.6 | 4 +- dev/deps/spark-deps-hadoop-2.7 | 4 +- dev/make-distribution.sh | 25 +++-- dev/mima | 6 +- dev/run-tests.py | 11 ++- docs/sql-programming-guide.md | 7 +- examples/pom.xml | 80 ++-------------- .../spark/launcher/AbstractCommandBuilder.java | 47 +++++----- .../apache/spark/launcher/CommandBuilderUtils.java | 4 +- .../spark/launcher/SparkSubmitCommandBuilder.java | 11 ++- pom.xml | 44 ++++++--- project/SparkBuild.scala | 45 +++++---- python/pyspark/streaming/tests.py | 6 +- python/run-tests.py | 18 +++- .../thriftserver/HiveThriftServer2Suites.scala | 4 + sql/hive/pom.xml | 24 ----- .../org/apache/spark/deploy/yarn/Client.scala | 3 - .../org/apache/spark/deploy/yarn/ClientSuite.scala | 2 +- 26 files changed, 231 insertions(+), 319 deletions(-) diff --git a/assembly/pom.xml b/assembly/pom.xml index 477d4931c3..22cbac06ca 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -33,9 +33,8 @@ assembly - scala-${scala.binary.version} - spark-assembly-${project.version}-hadoop${hadoop.version}.jar - ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} + none + package @@ -69,6 +68,17 @@ spark-repl_${scala.binary.version} ${project.version} + + + + com.google.guava + guava + ${hadoop.deps.scope} + @@ -87,75 +97,26 @@ true - - - org.apache.maven.plugins - maven-antrun-plugin - - - package - - run - - - - - - - - - - - - - + org.apache.maven.plugins - maven-shade-plugin - - false - ${spark.jar} - - - *:* - - - - - *:* - - org/datanucleus/** - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - - package - - shade - - - - - - META-INF/services/org.apache.hadoop.fs.FileSystem - - - reference.conf - - - log4j.properties - - - - - - - + maven-antrun-plugin + + + package + + run + + + + + + + + + + + diff --git a/bin/spark-class 
b/bin/spark-class index e710e388be..b489591778 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -36,21 +36,20 @@ else fi # Find Spark jars. -# TODO: change the directory name when Spark jars move from "lib". if [ -f "${SPARK_HOME}/RELEASE" ]; then - SPARK_JARS_DIR="${SPARK_HOME}/lib" + SPARK_JARS_DIR="${SPARK_HOME}/jars" else - SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION" + SPARK_JARS_DIR="${SPARK_HOME}/assembly/target/scala-$SPARK_SCALA_VERSION/jars" fi -if [ ! -d "$SPARK_JARS_DIR" ]; then +if [ ! -d "$SPARK_JARS_DIR" ] && [ -z "$SPARK_TESTING$SPARK_SQL_TESTING" ]; then echo "Failed to find Spark jars directory ($SPARK_JARS_DIR)." 1>&2 echo "You need to build Spark before running this program." 1>&2 exit 1 +else + LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*" fi -LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*" - # Add the launcher build dir to the classpath if requested. if [ -n "$SPARK_PREPEND_CLASSES" ]; then LAUNCH_CLASSPATH="${SPARK_HOME}/launcher/target/scala-$SPARK_SCALA_VERSION/classes:$LAUNCH_CLASSPATH" diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 565b87c102..579efff909 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -29,11 +29,10 @@ if "x%1"=="x" ( ) rem Find Spark jars. -rem TODO: change the directory name when Spark jars move from "lib". if exist "%SPARK_HOME%\RELEASE" ( - set SPARK_JARS_DIR="%SPARK_HOME%\lib" + set SPARK_JARS_DIR="%SPARK_HOME%\jars" ) else ( - set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%" + set SPARK_JARS_DIR="%SPARK_HOME%\assembly\target\scala-%SPARK_SCALA_VERSION%\jars" ) if not exist "%SPARK_JARS_DIR%"\ ( diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 50bcf85805..c304629bcd 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1121,9 +1121,9 @@ private[spark] object Utils extends Logging { extraEnvironment: Map[String, String] = Map.empty, redirectStderr: Boolean = true): String = { val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr) - val output = new StringBuffer + val output = new StringBuilder val threadName = "read stdout for " + command(0) - def appendToOutput(s: String): Unit = output.append(s) + def appendToOutput(s: String): Unit = output.append(s).append("\n") val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput) val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 280e496498..4fa9f9a8f5 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -201,24 +201,29 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { // Make sure only logging errors val logger = Logger.getRootLogger + val oldLogLevel = logger.getLevel logger.setLevel(Level.ERROR) - logger.addAppender(mockAppender) + try { + logger.addAppender(mockAppender) - val testOutputStream = new PipedOutputStream() - val testInputStream = new PipedInputStream(testOutputStream) + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) - // Close the stream before appender tries to read will cause an IOException - 
testInputStream.close() - testOutputStream.close() - val appender = FileAppender(testInputStream, testFile, new SparkConf) + // Close the stream before appender tries to read will cause an IOException + testInputStream.close() + testOutputStream.close() + val appender = FileAppender(testInputStream, testFile, new SparkConf) - appender.awaitTermination() + appender.awaitTermination() - // If InputStream was closed without first stopping the appender, an exception will be logged - verify(mockAppender, atLeast(1)).doAppend(loggingEventCaptor.capture) - val loggingEvent = loggingEventCaptor.getValue - assert(loggingEvent.getThrowableInformation !== null) - assert(loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException]) + // If InputStream was closed without first stopping the appender, an exception will be logged + verify(mockAppender, atLeast(1)).doAppend(loggingEventCaptor.capture) + val loggingEvent = loggingEventCaptor.getValue + assert(loggingEvent.getThrowableInformation !== null) + assert(loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException]) + } finally { + logger.setLevel(oldLogLevel) + } } test("file appender async close stream gracefully") { @@ -228,30 +233,35 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging { // Make sure only logging errors val logger = Logger.getRootLogger + val oldLogLevel = logger.getLevel logger.setLevel(Level.ERROR) - logger.addAppender(mockAppender) + try { + logger.addAppender(mockAppender) - val testOutputStream = new PipedOutputStream() - val testInputStream = new PipedInputStream(testOutputStream) with LatchedInputStream + val testOutputStream = new PipedOutputStream() + val testInputStream = new PipedInputStream(testOutputStream) with LatchedInputStream - // Close the stream before appender tries to read will cause an IOException - testInputStream.close() - testOutputStream.close() - val appender = FileAppender(testInputStream, testFile, new SparkConf) + // Close the stream before appender tries to read will cause an IOException + testInputStream.close() + testOutputStream.close() + val appender = FileAppender(testInputStream, testFile, new SparkConf) - // Stop the appender before an IOException is called during read - testInputStream.latchReadStarted.await() - appender.stop() - testInputStream.latchReadProceed.countDown() + // Stop the appender before an IOException is called during read + testInputStream.latchReadStarted.await() + appender.stop() + testInputStream.latchReadProceed.countDown() - appender.awaitTermination() + appender.awaitTermination() - // Make sure no IOException errors have been logged as a result of appender closing gracefully - verify(mockAppender, atLeast(0)).doAppend(loggingEventCaptor.capture) - import scala.collection.JavaConverters._ - loggingEventCaptor.getAllValues.asScala.foreach { loggingEvent => - assert(loggingEvent.getThrowableInformation === null - || !loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException]) + // Make sure no IOException errors have been logged as a result of appender closing gracefully + verify(mockAppender, atLeast(0)).doAppend(loggingEventCaptor.capture) + import scala.collection.JavaConverters._ + loggingEventCaptor.getAllValues.asScala.foreach { loggingEvent => + assert(loggingEvent.getThrowableInformation === null + || !loggingEvent.getThrowableInformation.getThrowable.isInstanceOf[IOException]) + } + } finally { + logger.setLevel(oldLogLevel) } } diff --git a/dev/deps/spark-deps-hadoop-2.2 
b/dev/deps/spark-deps-hadoop-2.2 index 3865a9fb16..2c24366cc3 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -12,7 +12,6 @@ asm-3.1.jar asm-commons-3.1.jar asm-tree-3.1.jar avro-1.7.7.jar -avro-ipc-1.7.7-tests.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar bonecp-0.8.0.RELEASE.jar @@ -61,6 +60,7 @@ grizzly-http-2.1.2.jar grizzly-http-server-2.1.2.jar grizzly-http-servlet-2.1.2.jar grizzly-rcm-2.1.2.jar +guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar hadoop-annotations-2.2.0.jar @@ -164,7 +164,6 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar @@ -177,7 +176,6 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar -unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 4313799da7..e9cb0d8f3e 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -12,7 +12,6 @@ asm-3.1.jar asm-commons-3.1.jar asm-tree-3.1.jar avro-1.7.7.jar -avro-ipc-1.7.7-tests.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar @@ -56,6 +55,7 @@ eigenbase-properties-1.1.5.jar geronimo-annotation_1.0_spec-1.1.1.jar geronimo-jaspic_1.0_spec-1.0.jar geronimo-jta_1.1_spec-1.1.1.jar +guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar hadoop-annotations-2.3.0.jar @@ -155,7 +155,6 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar @@ -168,7 +167,6 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar -unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 910ea685f2..d8d1840da5 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -12,7 +12,6 @@ asm-3.1.jar asm-commons-3.1.jar asm-tree-3.1.jar avro-1.7.7.jar -avro-ipc-1.7.7-tests.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar @@ -56,6 +55,7 @@ eigenbase-properties-1.1.5.jar geronimo-annotation_1.0_spec-1.1.1.jar geronimo-jaspic_1.0_spec-1.0.jar geronimo-jta_1.1_spec-1.1.1.jar +guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar hadoop-annotations-2.4.0.jar @@ -156,7 +156,6 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar @@ -169,7 +168,6 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar -unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 0692f24e47..8beede1e38 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -16,7 +16,6 @@ asm-3.1.jar asm-commons-3.1.jar asm-tree-3.1.jar avro-1.7.7.jar -avro-ipc-1.7.7-tests.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar @@ -61,6 +60,7 @@ geronimo-annotation_1.0_spec-1.1.1.jar geronimo-jaspic_1.0_spec-1.0.jar geronimo-jta_1.1_spec-1.1.1.jar gson-2.2.4.jar +guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar hadoop-annotations-2.6.0.jar @@ -162,7 +162,6 @@ scala-parser-combinators_2.11-1.0.4.jar 
scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar @@ -175,7 +174,6 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar -unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index e397558e05..a9d814f944 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -16,7 +16,6 @@ asm-3.1.jar asm-commons-3.1.jar asm-tree-3.1.jar avro-1.7.7.jar -avro-ipc-1.7.7-tests.jar avro-ipc-1.7.7.jar avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar @@ -61,6 +60,7 @@ geronimo-annotation_1.0_spec-1.1.1.jar geronimo-jaspic_1.0_spec-1.0.jar geronimo-jta_1.1_spec-1.1.1.jar gson-2.2.4.jar +guava-14.0.1.jar guice-3.0.jar guice-servlet-3.0.jar hadoop-annotations-2.7.0.jar @@ -163,7 +163,6 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -servlet-api-2.5.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar @@ -176,7 +175,6 @@ stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar -unused-1.0.0.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar diff --git a/dev/make-distribution.sh b/dev/make-distribution.sh index dbdd42ff9e..4f7544f6ea 100755 --- a/dev/make-distribution.sh +++ b/dev/make-distribution.sh @@ -160,28 +160,35 @@ echo -e "\$ ${BUILD_COMMAND[@]}\n" # Make directories rm -rf "$DISTDIR" -mkdir -p "$DISTDIR/lib" +mkdir -p "$DISTDIR/jars" echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DISTDIR/RELEASE" echo "Build flags: $@" >> "$DISTDIR/RELEASE" # Copy jars -cp "$SPARK_HOME"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" -# This will fail if the -Pyarn profile is not provided -# In this case, silence the error and ignore the return code of this command -cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : +cp "$SPARK_HOME"/assembly/target/scala*/jars/* "$DISTDIR/jars/" + +# Only create the yarn directory if the yarn artifacts were build. +if [ -f "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar ]; then + mkdir "$DISTDIR"/yarn + cp "$SPARK_HOME"/common/network-yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/yarn" +fi # Copy examples and dependencies mkdir -p "$DISTDIR/examples/jars" cp "$SPARK_HOME"/examples/target/scala*/jars/* "$DISTDIR/examples/jars" +# Deduplicate jars that have already been packaged as part of the main Spark dependencies. 
+for f in "$DISTDIR/examples/jars/"*; do + name=$(basename "$f") + if [ -f "$DISTDIR/jars/$name" ]; then + rm "$DISTDIR/examples/jars/$name" + fi +done + # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" cp -r "$SPARK_HOME"/examples/src/main "$DISTDIR/examples/src/" -if [ "$SPARK_HIVE" == "1" ]; then - cp "$SPARK_HOME"/lib_managed/jars/datanucleus*.jar "$DISTDIR/lib/" -fi - # Copy license and ASF files cp "$SPARK_HOME/LICENSE" "$DISTDIR" cp -r "$SPARK_HOME/licenses" "$DISTDIR" diff --git a/dev/mima b/dev/mima index ea746e6f01..c355349045 100755 --- a/dev/mima +++ b/dev/mima @@ -25,8 +25,8 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" SPARK_PROFILES="-Pyarn -Pspark-ganglia-lgpl -Pkinesis-asl -Phive-thriftserver -Phive" -TOOLS_CLASSPATH="$(build/sbt "export tools/fullClasspath" | tail -n1)" -OLD_DEPS_CLASSPATH="$(build/sbt $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" +TOOLS_CLASSPATH="$(build/sbt -DcopyDependencies=false "export tools/fullClasspath" | tail -n1)" +OLD_DEPS_CLASSPATH="$(build/sbt -DcopyDependencies=false $SPARK_PROFILES "export oldDeps/fullClasspath" | tail -n1)" rm -f .generated-mima* @@ -36,7 +36,7 @@ java \ -cp "$TOOLS_CLASSPATH:$OLD_DEPS_CLASSPATH" \ org.apache.spark.tools.GenerateMIMAIgnore -echo -e "q\n" | build/sbt mimaReportBinaryIssues | grep -v -e "info.*Resolving" +echo -e "q\n" | build/sbt -DcopyDependencies=false "$@" mimaReportBinaryIssues | grep -v -e "info.*Resolving" ret_val=$? if [ $ret_val != 0 ]; then diff --git a/dev/run-tests.py b/dev/run-tests.py index c2944747ee..cbe347274e 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -350,7 +350,7 @@ def build_spark_sbt(hadoop_version): def build_spark_assembly_sbt(hadoop_version): # Enable all of the profiles for the build: build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags - sbt_goals = ["assembly/assembly"] + sbt_goals = ["assembly/package"] profiles_and_goals = build_profiles + sbt_goals print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) @@ -371,9 +371,10 @@ def build_apache_spark(build_tool, hadoop_version): build_spark_sbt(hadoop_version) -def detect_binary_inop_with_mima(): +def detect_binary_inop_with_mima(hadoop_version): + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags set_title_and_block("Detecting binary incompatibilities with MiMa", "BLOCK_MIMA") - run_cmd([os.path.join(SPARK_HOME, "dev", "mima")]) + run_cmd([os.path.join(SPARK_HOME, "dev", "mima")] + build_profiles) def run_scala_tests_maven(test_profiles): @@ -571,8 +572,8 @@ def main(): # backwards compatibility checks if build_tool == "sbt": # Note: compatibility tests only supported in sbt for now - detect_binary_inop_with_mima() - # Since we did not build assembly/assembly before running dev/mima, we need to + detect_binary_inop_with_mima(hadoop_version) + # Since we did not build assembly/package before running dev/mima, we need to # do it here because the tests still rely on it; see SPARK-13294 for details. build_spark_assembly_sbt(hadoop_version) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2fdc97f8a0..274a8edb0c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1687,12 +1687,7 @@ on all of the worker nodes, as they will need access to the Hive serialization a (SerDes) in order to access data stored in Hive. 
Configuration of Hive is done by placing your `hive-site.xml`, `core-site.xml` (for security configuration), - `hdfs-site.xml` (for HDFS configuration) file in `conf/`. Please note when running -the query on a YARN cluster (`cluster` mode), the `datanucleus` jars under the `lib` directory -and `hive-site.xml` under `conf/` directory need to be available on the driver and all executors launched by the -YARN cluster. The convenient way to do this is adding them through the `--jars` option and `--file` option of the -`spark-submit` command. - +`hdfs-site.xml` (for HDFS configuration) file in `conf/`.
    diff --git a/examples/pom.xml b/examples/pom.xml index b7f37978b9..4a20370f06 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -27,13 +27,16 @@ org.apache.spark spark-examples_2.11 - - examples - jar Spark Project Examples http://spark.apache.org/ + + examples + none + package + + org.apache.spark @@ -75,23 +78,6 @@ spark-streaming-kafka_${scala.binary.version} ${project.version} - - org.apache.hbase - hbase-testing-util - ${hbase.version} - ${hbase.deps.scope} - - - - org.apache.hbase - hbase-annotations - - - org.jruby - jruby-complete - - - org.apache.hbase hbase-protocol @@ -139,6 +125,10 @@ org.apache.hbase hbase-annotations + + org.apache.hbase + hbase-common + org.apache.hadoop hadoop-core @@ -208,13 +198,6 @@ ${hbase.version} ${hbase.deps.scope} - - org.apache.hbase - hbase-hadoop-compat - ${hbase.version} - test-jar - test - org.apache.commons commons-math3 @@ -294,17 +277,6 @@ scopt_${scala.binary.version} 3.3.0 - - - - org.scala-lang - scala-library - provided - - @@ -325,38 +297,6 @@ true - - org.apache.maven.plugins - maven-jar-plugin - - - prepare-test-jar - none - - test-jar - - - - - ${jars.target.dir} - - - - org.apache.maven.plugins - maven-dependency-plugin - - - package - - copy-dependencies - - - runtime - ${jars.target.dir} - - - - diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index d02b2a4994..7a5e37c501 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -144,10 +144,26 @@ abstract class AbstractCommandBuilder { boolean isTesting = "1".equals(getenv("SPARK_TESTING")); if (prependClasses || isTesting) { String scala = getScalaVersion(); - List projects = Arrays.asList("core", "repl", "mllib", "graphx", - "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher", - "common/network-common", "common/network-shuffle", "common/network-yarn"); + List projects = Arrays.asList( + "common/network-common", + "common/network-shuffle", + "common/network-yarn", + "common/sketch", + "common/tags", + "common/unsafe", + "core", + "examples", + "graphx", + "launcher", + "mllib", + "repl", + "sql/catalyst", + "sql/core", + "sql/hive", + "sql/hive-thriftserver", + "streaming", + "yarn" + ); if (prependClasses) { if (!isTesting) { System.err.println( @@ -174,31 +190,12 @@ abstract class AbstractCommandBuilder { // Add Spark jars to the classpath. For the testing case, we rely on the test code to set and // propagate the test classpath appropriately. For normal invocation, look for the jars // directory under SPARK_HOME. - String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting); + boolean isTestingSql = "1".equals(getenv("SPARK_SQL_TESTING")); + String jarsDir = findJarsDir(getSparkHome(), getScalaVersion(), !isTesting && !isTestingSql); if (jarsDir != null) { addToClassPath(cp, join(File.separator, jarsDir, "*")); } - // Datanucleus jars must be included on the classpath. Datanucleus jars do not work if only - // included in the uber jar as plugin.xml metadata is lost. 
Both sbt and maven will populate - // "lib_managed/jars/" with the datanucleus jars when Spark is built with Hive - File libdir; - if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); - } else { - libdir = new File(sparkHome, "lib_managed/jars"); - } - - if (libdir.isDirectory()) { - for (File jar : libdir.listFiles()) { - if (jar.getName().startsWith("datanucleus-")) { - addToClassPath(cp, jar.getAbsolutePath()); - } - } - } else { - checkState(isTesting, "Library directory '%s' does not exist.", libdir.getAbsolutePath()); - } - addToClassPath(cp, getenv("HADOOP_CONF_DIR")); addToClassPath(cp, getenv("YARN_CONF_DIR")); addToClassPath(cp, getenv("SPARK_DIST_CLASSPATH")); diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index a08c8dcba4..91586aad7b 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -358,12 +358,12 @@ class CommandBuilderUtils { // TODO: change to the correct directory once the assembly build is changed. File libdir; if (new File(sparkHome, "RELEASE").isFile()) { - libdir = new File(sparkHome, "lib"); + libdir = new File(sparkHome, "jars"); checkState(!failIfNotFound || libdir.isDirectory(), "Library directory '%s' does not exist.", libdir.getAbsolutePath()); } else { - libdir = new File(sparkHome, String.format("assembly/target/scala-%s", scalaVersion)); + libdir = new File(sparkHome, String.format("assembly/target/scala-%s/jars", scalaVersion)); if (!libdir.isDirectory()) { checkState(!failIfNotFound, "Library directory '%s' does not exist; make sure Spark is built.", diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 56e4107c5a..c31c42cd3a 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -336,6 +336,7 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } private List findExamplesJars() { + boolean isTesting = "1".equals(getenv("SPARK_TESTING")); List examplesJars = new ArrayList<>(); String sparkHome = getSparkHome(); @@ -346,11 +347,15 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { jarsDir = new File(sparkHome, String.format("examples/target/scala-%s/jars", getScalaVersion())); } - checkState(jarsDir.isDirectory(), "Examples jars directory '%s' does not exist.", + + boolean foundDir = jarsDir.isDirectory(); + checkState(isTesting || foundDir, "Examples jars directory '%s' does not exist.", jarsDir.getAbsolutePath()); - for (File f: jarsDir.listFiles()) { - examplesJars.add(f.getAbsolutePath()); + if (foundDir) { + for (File f: jarsDir.listFiles()) { + examplesJars.add(f.getAbsolutePath()); + } } return examplesJars; } diff --git a/pom.xml b/pom.xml index e135c92c07..984b2859ef 100644 --- a/pom.xml +++ b/pom.xml @@ -185,6 +185,10 @@ ${project.build.directory}/scala-${scala.binary.version}/jars + + prepare-package + none + - - org.spark-project.spark - unused - 1.0.0 - + + org.apache.avro + avro-ipc + tests + ${avro.version} + test + org.apache.avro avro-mapred @@ -1521,6 +1524,10 @@ org.codehaus.groovy groovy-all + + javax.servlet + servlet-api + @@ -1916,6 +1923,7 @@ --> ${test_classpath} 1 + ${scala.binary.version} 
1 ${test.java.home} @@ -1964,6 +1972,7 @@ --> ${test_classpath} 1 + ${scala.binary.version} 1 ${test.java.home} @@ -2146,6 +2155,7 @@ 2.10 + generate-test-classpath test-compile build-classpath @@ -2155,6 +2165,17 @@ test_classpath + + copy-module-dependencies + ${build.copyDependenciesPhase} + + copy-dependencies + + + runtime + ${jars.target.dir} + + @@ -2169,9 +2190,6 @@ false - - org.spark-project.spark:unused - org.eclipse.jetty:jetty-io org.eclipse.jetty:jetty-http org.eclipse.jetty:jetty-continuation @@ -2302,7 +2320,7 @@ prepare-test-jar - prepare-package + ${build.testJarPhase} test-jar diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5d62b688b9..b32480b164 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -57,11 +57,12 @@ object BuildCommons { Seq("yarn", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = - Seq("assembly", "network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") + val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) = + Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-assembly", "streaming-kinesis-asl-assembly") .map(ProjectRef(buildLocation, _)) - val copyJarsProjects@Seq(examples) = Seq("examples").map(ProjectRef(buildLocation, _)) + val copyJarsProjects@Seq(assembly, examples) = Seq("assembly", "examples") + .map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") // Root project. @@ -263,8 +264,14 @@ object SparkBuild extends PomBuild { /* Unsafe settings */ enable(Unsafe.settings)(unsafe) - /* Set up tasks to copy dependencies during packaging. */ - copyJarsProjects.foreach(enable(CopyDependencies.settings)) + /* + * Set up tasks to copy dependencies during packaging. This step can be disabled in the command + * line, so that dev/mima can run without trying to copy these files again and potentially + * causing issues. 
+ */ + if (!"false".equals(System.getProperty("copyDependencies"))) { + copyJarsProjects.foreach(enable(CopyDependencies.settings)) + } /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) @@ -477,8 +484,6 @@ object Assembly { val hadoopVersion = taskKey[String]("The version of hadoop that spark is compiled against.") - val deployDatanucleusJars = taskKey[Unit]("Deploy datanucleus jars to the spark/lib_managed/jars directory") - lazy val settings = assemblySettings ++ Seq( test in assembly := {}, hadoopVersion := { @@ -497,27 +502,13 @@ object Assembly { s"${mName}-test-${v}.jar" }, mergeStrategy in assembly := { - case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard case "log4j.properties" => MergeStrategy.discard case m if m.toLowerCase.startsWith("meta-inf/services/") => MergeStrategy.filterDistinctLines case "reference.conf" => MergeStrategy.concat case _ => MergeStrategy.first - }, - deployDatanucleusJars := { - val jars: Seq[File] = (fullClasspath in assembly).value.map(_.data) - .filter(_.getPath.contains("org.datanucleus")) - var libManagedJars = new File(BuildCommons.sparkHome, "lib_managed/jars") - libManagedJars.mkdirs() - jars.foreach { jar => - val dest = new File(libManagedJars, jar.getName) - if (!dest.exists()) { - Files.copy(jar.toPath, dest.toPath) - } - } - }, - assembly <<= assembly.dependsOn(deployDatanucleusJars) + } ) } @@ -698,6 +689,13 @@ object Java8TestSettings { object TestSettings { import BuildCommons._ + private val scalaBinaryVersion = + if (System.getProperty("scala-2.10") == "true") { + "2.10" + } else { + "2.11" + } + lazy val settings = Seq ( // Fork new JVMs for tests and set Java options for those fork := true, @@ -707,6 +705,7 @@ object TestSettings { "SPARK_DIST_CLASSPATH" -> (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":"), "SPARK_PREPEND_CLASSES" -> "1", + "SPARK_SCALA_VERSION" -> scalaBinaryVersion, "SPARK_TESTING" -> "1", "JAVA_HOME" -> sys.env.get("JAVA_HOME").getOrElse(sys.props("java.home"))), javaOptions in Test += s"-Djava.io.tmpdir=$testTempDir", @@ -744,7 +743,7 @@ object TestSettings { // Make sure the test temp directory exists. resourceGenerators in Test <+= resourceManaged in Test map { outDir: File => if (!new File(testTempDir).isDirectory()) { - require(new File(testTempDir).mkdirs()) + require(new File(testTempDir).mkdirs(), s"Error creating temp directory $testTempDir.") } Seq[File]() }, diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index d010c0e008..148bf7e8ff 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1482,7 +1482,7 @@ def search_kafka_assembly_jar(): raise Exception( ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " - "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " + "'build/sbt assembly/package streaming-kafka-assembly/assembly' or " "'build/mvn package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " @@ -1548,7 +1548,7 @@ if __name__ == "__main__": elif are_kinesis_tests_enabled is False: sys.stderr.write("Skipping all Kinesis Python tests as the optional Kinesis project was " "not compiled into a JAR. 
To run these tests, " - "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/assembly " + "you need to build Spark with 'build/sbt -Pkinesis-asl assembly/package " "streaming-kinesis-asl-assembly/assembly' or " "'build/mvn -Pkinesis-asl package' before running this test.") else: @@ -1556,7 +1556,7 @@ if __name__ == "__main__": ("Failed to find Spark Streaming Kinesis assembly jar in %s. " % kinesis_asl_assembly_dir) + "You need to build Spark with 'build/sbt -Pkinesis-asl " - "assembly/assembly streaming-kinesis-asl-assembly/assembly'" + "assembly/package streaming-kinesis-asl-assembly/assembly'" "or 'build/mvn -Pkinesis-asl package' before running this test.") sys.stderr.write("Running tests: %s \n" % (str(testcases))) diff --git a/python/run-tests.py b/python/run-tests.py index a9f8854e6f..38b3bb84c1 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -53,11 +53,25 @@ LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") FAILURE_REPORTING_LOCK = Lock() LOGGER = logging.getLogger() +# Find out where the assembly jars are located. +for scala in ["2.11", "2.10"]: + build_dir = os.path.join(SPARK_HOME, "assembly", "target", "scala-" + scala) + if os.path.isdir(build_dir): + SPARK_DIST_CLASSPATH = os.path.join(build_dir, "jars", "*") + break +else: + raise Exception("Cannot find assembly build directory, please build Spark first.") + def run_individual_python_test(test_name, pyspark_python): env = dict(os.environ) - env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python), - 'PYSPARK_DRIVER_PYTHON': which(pyspark_python)}) + env.update({ + 'SPARK_DIST_CLASSPATH': SPARK_DIST_CLASSPATH, + 'SPARK_TESTING': '1', + 'SPARK_PREPEND_CLASSES': '1', + 'PYSPARK_PYTHON': which(pyspark_python), + 'PYSPARK_DRIVER_PYTHON': which(pyspark_python) + }) LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() try: diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 33af624cfd..2c7358e59a 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -763,11 +763,15 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl extraEnvironment = Map( // Disables SPARK_TESTING to exclude log4j.properties in test directories. "SPARK_TESTING" -> "0", + // But set SPARK_SQL_TESTING to make spark-class happy. + "SPARK_SQL_TESTING" -> "1", // Points SPARK_PID_DIR to SPARK_HOME, otherwise only 1 Thrift server instance can be // started at a time, which is not Jenkins friendly. 
"SPARK_PID_DIR" -> pidDir.getCanonicalPath), redirectStderr = true) + logInfo(s"COMMAND: $command") + logInfo(s"OUTPUT: $lines") lines.split("\n").collectFirst { case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) }.getOrElse { diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 58efd80512..61504becf1 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -225,30 +225,6 @@ -da -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - - - - org.apache.maven.plugins - maven-dependency-plugin - - - copy-dependencies - package - - copy-dependencies - - - - ${basedir}/../../lib_managed/jars - false - false - true - org.datanucleus - - - - - diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4dd3ccdf37..336e29fc6b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -447,9 +447,6 @@ private[spark] class Client( * * Note that the archive cannot be a "local" URI. If none of the above settings are found, * then upload all files found in $SPARK_HOME/jars. - * - * TODO: currently the code looks in $SPARK_HOME/lib while the work to replace assemblies - * with a directory full of jars is ongoing. */ val sparkArchive = sparkConf.get(SPARK_ARCHIVE) if (sparkArchive.isDefined) { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 2eaafa072a..74e268dc48 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -273,7 +273,7 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll test("distribute local spark jars") { val temp = Utils.createTempDir() - val jarsDir = new File(temp, "lib") + val jarsDir = new File(temp, "jars") assert(jarsDir.mkdir()) val jar = TestUtils.createJarWithFiles(Map(), jarsDir) new FileOutputStream(new File(temp, "RELEASE")).close() -- cgit v1.2.3 From a172e11cba6f917baf5bd6c4f83dc6689932de9a Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Mon, 4 Apr 2016 16:55:59 -0700 Subject: [SPARK-14366] Remove sbt-idea plugin ## What changes were proposed in this pull request? Remove sbt-idea plugin as importing sbt project provides much better support. Author: Luciano Resende Closes #12151 from lresende/SPARK-14366. --- project/plugins.sbt | 2 -- 1 file changed, 2 deletions(-) diff --git a/project/plugins.sbt b/project/plugins.sbt index 4929ba3c4d..44ec3a12ae 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -2,8 +2,6 @@ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "4.0.0") -addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") - addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.8.2") addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.8.0") -- cgit v1.2.3 From 7201f033ce520259b6d07ea5ead92272cac92363 Mon Sep 17 00:00:00 2001 From: Guillaume Poulin Date: Tue, 5 Apr 2016 02:54:38 +0100 Subject: [SPARK-12425][STREAMING] DStream union optimisation Use PartitionerAwareUnionRDD when possbile for optimizing shuffling and preserving the partitioner. Author: Guillaume Poulin Closes #10382 from gpoulin/dstream_union_optimisation. 
--- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 6 +----- .../org/apache/spark/streaming/dstream/UnionDStream.scala | 4 ++-- .../org/apache/spark/streaming/dstream/WindowedDStream.scala | 11 ++--------- 3 files changed, 5 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 4a0a2199ef..032939b49a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -568,11 +568,7 @@ abstract class RDD[T: ClassTag]( * times (use `.distinct()` to eliminate them). */ def union(other: RDD[T]): RDD[T] = withScope { - if (partitioner.isDefined && other.partitioner == partitioner) { - new PartitionerAwareUnionRDD(sc, Array(this, other)) - } else { - new UnionRDD(sc, Array(this, other)) - } + sc.union(this, other) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index c1846a31f6..d46c0a01e0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark.SparkException -import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} private[streaming] @@ -45,7 +45,7 @@ class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) s" time $validTime") } if (rdds.nonEmpty) { - Some(new UnionRDD(ssc.sc, rdds)) + Some(ssc.sc.union(rdds)) } else { None } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index ee50a8d024..fe0f875525 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -19,7 +19,7 @@ package org.apache.spark.streaming.dstream import scala.reflect.ClassTag -import org.apache.spark.rdd.{PartitionerAwareUnionRDD, RDD, UnionRDD} +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.Duration @@ -63,13 +63,6 @@ class WindowedDStream[T: ClassTag]( override def compute(validTime: Time): Option[RDD[T]] = { val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime) val rddsInWindow = parent.slice(currentWindow) - val windowRDD = if (rddsInWindow.flatMap(_.partitioner).distinct.length == 1) { - logDebug("Using partition aware union for windowing at " + validTime) - new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow) - } else { - logDebug("Using normal union for windowing at " + validTime) - new UnionRDD(ssc.sc, rddsInWindow) - } - Some(windowRDD) + Some(ssc.sc.union(rddsInWindow)) } } -- cgit v1.2.3 From ba24d1ee9a1d97ca82282f3b811ec011c4285b99 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 4 Apr 2016 19:04:09 -0700 Subject: [SPARK-14287] isStreaming method for Dataset With the addition of StreamExecution (ContinuousQuery) to Datasets, data will become unbounded. With unbounded data, the execution of some methods and operations will not make sense, e.g. `Dataset.count()`. 
A simple API is required to check whether the data in a Dataset is bounded or unbounded. This will allow users to check whether their Dataset is in streaming mode or not. ML algorithms may check if the data is unbounded and throw an exception for example. The implementation of this method is simple, however naming it is the challenge. Some possible names for this method are: - isStreaming - isContinuous - isBounded - isUnbounded I've gone with `isStreaming` for now. We can change it before Spark 2.0 if we decide to come up with a different name. For that reason I've marked it as `Experimental` Author: Burak Yavuz Closes #12080 from brkyvz/is-streaming. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 15 +++++++++++++++ .../test/scala/org/apache/spark/sql/DatasetSuite.scala | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 8dfe8ff702..db2134b020 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython +import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -449,6 +450,20 @@ class Dataset[T] private[sql]( */ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] + /** + * Returns true if this [[Dataset]] contains one or more sources that continuously + * return data as it arrives. A [[Dataset]] that reads data from a streaming source + * must be executed as a [[ContinuousQuery]] using the `startStream()` method in + * [[DataFrameWriter]]. Methods that return a single answer, (e.g., `count()` or + * `collect()`) will throw an [[AnalysisException]] when there is a streaming + * source present. + * + * @group basic + * @since 2.0.0 + */ + @Experimental + def isStreaming: Boolean = logicalPlan.find(_.isInstanceOf[StreamingRelation]).isDefined + /** * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated, * and all cells will be aligned right. 
For example: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 2aa90568c3..e8e801084f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -23,6 +23,7 @@ import java.sql.{Date, Timestamp} import scala.language.postfixOps import org.apache.spark.sql.catalyst.encoders.OuterScopes +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} @@ -602,6 +603,23 @@ class DatasetSuite extends QueryTest with SharedSQLContext { TupleClass(1, "a") ) } + + test("isStreaming returns false for static Dataset") { + val data = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() + assert(!data.isStreaming, "static Dataset returned true for 'isStreaming'.") + } + + test("isStreaming returns true for streaming Dataset") { + val data = MemoryStream[Int].toDS() + assert(data.isStreaming, "streaming Dataset returned false for 'isStreaming'.") + } + + test("isStreaming returns true after static and streaming Dataset join") { + val static = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("a", "b") + val streaming = MemoryStream[Int].toDS().toDF("b") + val df = streaming.join(static, Seq("b")) + assert(df.isStreaming, "streaming Dataset returned false for 'isStreaming'.") + } } case class OtherTuple(_1: String, _2: Int) -- cgit v1.2.3 From 8f50574ab4021b9984b0017cd47ba012a894c19a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 4 Apr 2016 20:12:09 -0700 Subject: [SPARK-14386][ML] Changed spark.ml ensemble trees methods to return concrete types ## What changes were proposed in this pull request? In spark.ml, GBT and RandomForest expose the trait DecisionTreeModel in the trees method, but they should not since it is a private trait (and not ready to be made public). It will also be more useful to users if we return the concrete types. This PR: return concrete types The MIMA checks appear to be OK with this change. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #12158 from jkbradley/hide-dtm. 
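For illustration (not part of this patch): a hedged Scala sketch of what the concrete return type buys callers. `rfModel` is assumed to be a `RandomForestClassificationModel` obtained from an ordinary `RandomForestClassifier.fit` call; with `trees` typed as `Array[DecisionTreeClassificationModel]`, per-tree members can be used without casting.

```scala
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, RandomForestClassificationModel}

// Assumes rfModel is a fitted RandomForestClassificationModel.
def summarizeTrees(rfModel: RandomForestClassificationModel): Unit = {
  // The element type is now the concrete model class, not the private DecisionTreeModel trait.
  val trees: Array[DecisionTreeClassificationModel] = rfModel.trees
  trees.zipWithIndex.foreach { case (tree, i) =>
    println(s"tree $i: ${tree.numNodes} nodes, depth ${tree.depth}")
  }
}
```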
--- .../org/apache/spark/ml/classification/GBTClassifier.scala | 7 +++---- .../spark/ml/classification/RandomForestClassifier.scala | 6 +++--- .../org/apache/spark/ml/regression/GBTRegressor.scala | 7 +++---- .../apache/spark/ml/regression/RandomForestRegressor.scala | 5 +++-- .../main/scala/org/apache/spark/ml/tree/treeModels.scala | 14 +++++++++----- .../scala/org/apache/spark/ml/tree/impl/TreeTests.scala | 2 +- 6 files changed, 22 insertions(+), 19 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index bfefaf1a1a..bee90fb3a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -24,8 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeClassifierParams, - TreeEnsembleModel} +import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel} import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector @@ -190,7 +189,7 @@ final class GBTClassificationModel private[ml]( private val _treeWeights: Array[Double], @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] - with TreeEnsembleModel with Serializable { + with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + @@ -206,7 +205,7 @@ final class GBTClassificationModel private[ml]( this(uid, _trees, _treeWeights, -1) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 2ad893f4fa..cb42532271 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -155,8 +155,8 @@ final class RandomForestClassificationModel private[ml] ( @Since("1.6.0") override val numFeatures: Int, @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] - with RandomForestClassificationModelParams with TreeEnsembleModel with MLWritable - with Serializable { + with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.") @@ -172,7 +172,7 @@ final class RandomForestClassificationModel private[ml] ( this(Identifiable.randomUID("rfc"), trees, numFeatures, numClasses) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: 
Array[DecisionTreeClassificationModel] = _trees // Note: We may add support for weights (based on tree performance) later on. private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 02e124a1c0..cef7c643d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -23,8 +23,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, - TreeRegressorParams} +import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams} import org.apache.spark.ml.tree.impl.GradientBoostedTrees import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector @@ -177,7 +176,7 @@ final class GBTRegressionModel private[ml]( private val _treeWeights: Array[Double], override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] - with TreeEnsembleModel with Serializable { + with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + @@ -193,7 +192,7 @@ final class GBTRegressionModel private[ml]( this(uid, _trees, _treeWeights, -1) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index ba56b5cd3f..736cd9f776 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -142,7 +142,8 @@ final class RandomForestRegressionModel private[ml] ( private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] - with RandomForestRegressionModelParams with TreeEnsembleModel with MLWritable with Serializable { + with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") @@ -155,7 +156,7 @@ final class RandomForestRegressionModel private[ml] ( this(Identifiable.randomUID("rfr"), trees, numFeatures) @Since("1.4.0") - override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]] + override def trees: Array[DecisionTreeRegressionModel] = _trees // Note: We may add support for weights (based on tree performance) later on. 
private lazy val _treeWeights: Array[Double] = Array.fill[Double](_trees.length)(1.0) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 48b8fd19ad..db0ff28d82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.tree +import scala.reflect.ClassTag + import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.jackson.JsonMethods._ @@ -82,14 +84,16 @@ private[spark] trait DecisionTreeModel { * Abstraction for models which are ensembles of decision trees * * TODO: Add support for predicting probabilities and raw predictions SPARK-3727 + * + * @tparam M Type of tree model in this ensemble */ -private[ml] trait TreeEnsembleModel { +private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] { // Note: We use getTrees since subclasses of TreeEnsembleModel will store subclasses of // DecisionTreeModel. /** Trees in this ensemble. Warning: These have null parent Estimators. */ - def trees: Array[DecisionTreeModel] + def trees: Array[M] /** * Number of trees in ensemble @@ -148,7 +152,7 @@ private[ml] object TreeEnsembleModel { * If -1, then numFeatures is set based on the max feature index in all trees. * @return Feature importance values, of length numFeatures. */ - def featureImportances(trees: Array[DecisionTreeModel], numFeatures: Int): Vector = { + def featureImportances[M <: DecisionTreeModel](trees: Array[M], numFeatures: Int): Vector = { val totalImportances = new OpenHashMap[Int, Double]() trees.foreach { tree => // Aggregate feature importance vector for this tree @@ -199,7 +203,7 @@ private[ml] object TreeEnsembleModel { * If -1, then numFeatures is set based on the max feature index in all trees. * @return Feature importance values, of length numFeatures. */ - def featureImportances(tree: DecisionTreeModel, numFeatures: Int): Vector = { + def featureImportances[M <: DecisionTreeModel : ClassTag](tree: M, numFeatures: Int): Vector = { featureImportances(Array(tree), numFeatures) } @@ -386,7 +390,7 @@ private[ml] object EnsembleModelReadWrite { * @param path Path to which to save the ensemble model. * @param extraMetadata Metadata such as numFeatures, numClasses, numTrees. */ - def saveImpl[M <: Params with TreeEnsembleModel]( + def saveImpl[M <: Params with TreeEnsembleModel[_ <: DecisionTreeModel]]( instance: M, path: String, sql: SQLContext, diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala index bd5bd17147..b650a9f092 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala @@ -131,7 +131,7 @@ private[ml] object TreeTests extends SparkFunSuite { * Check if the two models are exactly the same. * If the models are not equal, this throws an exception. */ - def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = { + def checkEqual[M <: DecisionTreeModel](a: TreeEnsembleModel[M], b: TreeEnsembleModel[M]): Unit = { try { a.trees.zip(b.trees).foreach { case (treeA, treeB) => TreeTests.checkEqual(treeA, treeB) -- cgit v1.2.3 From 7db56244fa3dba92246bad6694f31bbf68ea47ec Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 5 Apr 2016 12:19:20 +0900 Subject: [SPARK-14368][PYSPARK] Support python.spark.worker.memory with upper-case unit. 
## What changes were proposed in this pull request? This fix tries to address the issue in PySpark where `spark.python.worker.memory` could only be configured with a lower case unit (`k`, `m`, `g`, `t`). This fix allows the upper case unit (`K`, `M`, `G`, `T`) to be used as well. This is to conform to the JVM memory string as is specified in the documentation . ## How was this patch tested? This fix adds additional test to cover the changes. Author: Yong Tang Closes #12163 from yongtang/SPARK-14368. --- python/pyspark/rdd.py | 2 +- python/pyspark/tests.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cd1f64e8aa..8978f028c5 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -115,7 +115,7 @@ def _parse_memory(s): 2048 """ units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024} - if s[-1] not in units: + if s[-1].lower() not in units: raise ValueError("invalid format: " + s) return int(float(s[:-1]) * units[s[-1].lower()]) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a5a83c7e38..40fccb8c00 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1966,6 +1966,18 @@ class ContextTests(unittest.TestCase): self.assertGreater(sc.startTime, 0) +class ConfTests(unittest.TestCase): + def test_memory_conf(self): + memoryList = ["1T", "1G", "1M", "1024K"] + for memory in memoryList: + sc = SparkContext(conf=SparkConf().set("spark.python.worker.memory", memory)) + l = list(range(1024)) + random.shuffle(l) + rdd = sc.parallelize(l, 4) + self.assertEqual(sorted(l), rdd.sortBy(lambda x: x).collect()) + sc.stop() + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): -- cgit v1.2.3 From 064623014e0d6dfb0376722f24e81027fde649de Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 5 Apr 2016 00:30:55 -0500 Subject: [SPARK-14359] Create built-in functions for typed aggregates in Java ## What changes were proposed in this pull request? This adds the corresponding Java static functions for built-in typed aggregates already exposed in Scala. ## How was this patch tested? Unit tests. rxin Author: Eric Liang Closes #12168 from ericl/sc-2794. 
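For illustration: a Java usage sketch that mirrors the new `JavaDatasetAggregatorSuite` test below; the `grouped` argument is assumed to be a `KeyValueGroupedDataset` built elsewhere (e.g. via `Dataset.groupByKey`).

```java
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.java.typed;
import scala.Tuple2;

public class TypedAggExample {
  // Average of the Integer field per key, using the new Java-friendly typed.avg.
  public static Dataset<Tuple2<String, Double>> averagePerKey(
      KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped) {
    return grouped.agg(typed.avg(
      new MapFunction<Tuple2<String, Integer>, Double>() {
        @Override
        public Double call(Tuple2<String, Integer> value) {
          return (double) value._2();
        }
      }));
  }
}
```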
--- .../sql/execution/aggregate/typedaggregators.scala | 33 +++++++++++++++ .../apache/spark/sql/expressions/java/typed.java | 42 +++++++++++++++++++ .../sql/sources/JavaDatasetAggregatorSuite.java | 49 ++++++++++++++++++++++ 3 files changed, 124 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala index 9afc29038b..7a18d0afce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.sql.TypedColumn +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -30,6 +33,8 @@ class TypedSum[IN, OUT : Numeric](f: IN => OUT) extends Aggregator[IN, OUT, OUT] override def reduce(b: OUT, a: IN): OUT = numeric.plus(b, f(a)) override def merge(b1: OUT, b2: OUT): OUT = numeric.plus(b1, b2) override def finish(reduction: OUT): OUT = reduction + + // TODO(ekl) java api support once this is exposed in scala } @@ -38,6 +43,13 @@ class TypedSumDouble[IN](f: IN => Double) extends Aggregator[IN, Double, Double] override def reduce(b: Double, a: IN): Double = b + f(a) override def merge(b1: Double, b2: Double): Double = b1 + b2 override def finish(reduction: Double): Double = reduction + + // Java api support + def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) + def toColumnJava(): TypedColumn[IN, java.lang.Double] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Double]] + } } @@ -46,6 +58,13 @@ class TypedSumLong[IN](f: IN => Long) extends Aggregator[IN, Long, Long] { override def reduce(b: Long, a: IN): Long = b + f(a) override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction + + // Java api support + def this(f: MapFunction[IN, java.lang.Long]) = this(x => f.call(x).asInstanceOf[Long]) + def toColumnJava(): TypedColumn[IN, java.lang.Long] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Long]] + } } @@ -56,6 +75,13 @@ class TypedCount[IN](f: IN => Any) extends Aggregator[IN, Long, Long] { } override def merge(b1: Long, b2: Long): Long = b1 + b2 override def finish(reduction: Long): Long = reduction + + // Java api support + def this(f: MapFunction[IN, Object]) = this(x => f.call(x)) + def toColumnJava(): TypedColumn[IN, java.lang.Long] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Long]] + } } @@ -66,4 +92,11 @@ class TypedAverage[IN](f: IN => Double) extends Aggregator[IN, (Double, Long), D override def merge(b1: (Double, Long), b2: (Double, Long)): (Double, Long) = { (b1._1 + b2._1, b1._2 + b2._2) } + + // Java api support + def this(f: MapFunction[IN, java.lang.Double]) = this(x => f.call(x).asInstanceOf[Double]) + def toColumnJava(): TypedColumn[IN, java.lang.Double] = { + toColumn(ExpressionEncoder(), ExpressionEncoder()) + .asInstanceOf[TypedColumn[IN, java.lang.Double]] + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java 
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java index cdba970d8f..8ff7b6549b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java @@ -18,7 +18,13 @@ package org.apache.spark.sql.expressions.java; import org.apache.spark.annotation.Experimental; +import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.execution.aggregate.TypedAverage; +import org.apache.spark.sql.execution.aggregate.TypedCount; +import org.apache.spark.sql.execution.aggregate.TypedSumDouble; +import org.apache.spark.sql.execution.aggregate.TypedSumLong; /** * :: Experimental :: @@ -30,5 +36,41 @@ import org.apache.spark.sql.Dataset; */ @Experimental public class typed { + // Note: make sure to keep in sync with typed.scala + /** + * Average aggregate function. + * + * @since 2.0.0 + */ + public static TypedColumn avg(MapFunction f) { + return new TypedAverage(f).toColumnJava(); + } + + /** + * Count aggregate function. + * + * @since 2.0.0 + */ + public static TypedColumn count(MapFunction f) { + return new TypedCount(f).toColumnJava(); + } + + /** + * Sum aggregate function for floating point (double) type. + * + * @since 2.0.0 + */ + public static TypedColumn sum(MapFunction f) { + return new TypedSumDouble(f).toColumnJava(); + } + + /** + * Sum aggregate function for integral (long, i.e. 64 bit integer) type. + * + * @since 2.0.0 + */ + public static TypedColumn sumLong(MapFunction f) { + return new TypedSumLong(f).toColumnJava(); + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index c4c455b6e6..c8d0eecd5c 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; import org.apache.spark.sql.expressions.Aggregator; +import org.apache.spark.sql.expressions.java.typed; import org.apache.spark.sql.test.TestSQLContext; /** @@ -120,4 +121,52 @@ public class JavaDatasetAggregatorSuite implements Serializable { return reduction; } } + + @Test + public void testTypedAggregationAverage() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.avg( + new MapFunction, Double>() { + public Double call(Tuple2 value) throws Exception { + return (double)(value._2() * 2); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationCount() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.count( + new MapFunction, Object>() { + public Object call(Tuple2 value) throws Exception { + return value; + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumDouble() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sum( + new MapFunction, Double>() { + public Double call(Tuple2 value) throws Exception { + return (double)value._2(); + } + })); 
+ Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumLong() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sumLong( + new MapFunction, Long>() { + public Long call(Tuple2 value) throws Exception { + return (long)value._2(); + } + })); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + } } -- cgit v1.2.3 From 2715bc68bd1661d207b1af5f44ae8d02aec9d4ec Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 5 Apr 2016 08:41:59 +0200 Subject: [SPARK-14348][SQL] Support native execution of SHOW TBLPROPERTIES command ## What changes were proposed in this pull request? This PR adds Native execution of SHOW TBLPROPERTIES command. Command Syntax: ``` SQL SHOW TBLPROPERTIES table_name[(property_key_literal)] ``` ## How was this patch tested? Tests added in HiveComandSuiie and DDLCommandSuite Author: Dilip Biswal Closes #12133 from dilipbiswal/dkb_show_tblproperties. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 + .../sql/catalyst/catalog/SessionCatalog.scala | 11 ++ .../spark/sql/execution/SparkSqlParser.scala | 37 ++++-- .../spark/sql/execution/command/commands.scala | 44 +++++++- .../sql/execution/command/DDLCommandSuite.scala | 8 ++ .../sql/hive/execution/HiveCommandSuite.scala | 125 +++++++++++++++++++++ .../spark/sql/hive/execution/SQLQuerySuite.scala | 22 ---- 7 files changed, 219 insertions(+), 30 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6cf47b5c30..27b01e0bed 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -116,6 +116,8 @@ statement | SHOW TABLES ((FROM | IN) db=identifier)? (LIKE? pattern=STRING)? #showTables | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases + | SHOW TBLPROPERTIES table=tableIdentifier + ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 569b99e414..3b8ce6373d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -163,6 +163,7 @@ class SessionCatalog( /** * Retrieve the metadata of an existing metastore table. * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then an [[AnalysisException]] is thrown. */ def getTable(name: TableIdentifier): CatalogTable = { val db = name.database.getOrElse(currentDb) @@ -271,6 +272,16 @@ class SessionCatalog( } } + /** + * Return whether a table with the specified name is a temporary table. + * + * Note: The temporary table cache is checked only when database is not + * explicitly specified. 
+ */ + def isTemporaryTable(name: TableIdentifier): Boolean = { + !name.database.isDefined && tempTables.contains(formatTableName(name.table)) + } + /** * List all tables in the specified database, including temporary tables. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index ff3ab7746c..fb106d1aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -92,6 +92,22 @@ class SparkSqlAstBuilder extends AstBuilder { ShowDatabasesCommand(Option(ctx.pattern).map(string)) } + /** + * A command for users to list the properties for a table. If propertyKey is specified, the value + * for the propertyKey is returned. If propertyKey is not specified, all the keys and their + * corresponding values are returned. + * The syntax of using this command in SQL is: + * {{{ + * SHOW TBLPROPERTIES table_name[('propertyKey')]; + * }}} + */ + override def visitShowTblProperties( + ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { + ShowTablePropertiesCommand( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.key).map(visitTablePropertyKey)) + } + /** * Create a [[RefreshTable]] logical plan. */ @@ -220,18 +236,25 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitTablePropertyList( ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { ctx.tableProperty.asScala.map { property => - // A key can either be a String or a collection of dot separated elements. We need to treat - // these differently. - val key = if (property.key.STRING != null) { - string(property.key.STRING) - } else { - property.key.getText - } + val key = visitTablePropertyKey(property.key) val value = Option(property.value).map(string).orNull key -> value }.toMap } + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + /** * Create a [[CreateDatabase]] command. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 4eb8d7ff0d..a4be3bc333 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -21,7 +21,7 @@ import java.util.NoSuchElementException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -380,6 +380,48 @@ case class ShowDatabasesCommand(databasePattern: Option[String]) extends Runnabl } } +/** + * A command for users to list the properties for a table If propertyKey is specified, the value + * for the propertyKey is returned. 
If propertyKey is not specified, all the keys and their + * corresponding values are returned. + * The syntax of using this command in SQL is: + * {{{ + * SHOW TBLPROPERTIES table_name[('propertyKey')]; + * }}} + */ +case class ShowTablePropertiesCommand( + table: TableIdentifier, + propertyKey: Option[String]) extends RunnableCommand { + + override val output: Seq[Attribute] = { + val schema = AttributeReference("value", StringType, nullable = false)() :: Nil + propertyKey match { + case None => AttributeReference("key", StringType, nullable = false)() :: schema + case _ => schema + } + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + + if (catalog.isTemporaryTable(table)) { + Seq.empty[Row] + } else { + val catalogTable = sqlContext.sessionState.catalog.getTable(table) + + propertyKey match { + case Some(p) => + val propValue = catalogTable + .properties + .getOrElse(p, s"Table ${catalogTable.qualifiedName} does not have property: $p") + Seq(Row(propValue)) + case None => + catalogTable.properties.map(p => Row(p._1, p._2)).toSeq + } + } + } +} + /** * A command for users to list all of the registered functions. * The syntax of using this command in SQL is: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 458f36e832..8b2a5979e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -773,4 +773,12 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + test("show tblproperties") { + val parsed1 = parser.parsePlan("SHOW TBLPROPERTIES tab1") + val expected1 = ShowTablePropertiesCommand(TableIdentifier("tab1", None), None) + val parsed2 = parser.parsePlan("SHOW TBLPROPERTIES tab1('propKey1')") + val expected2 = ShowTablePropertiesCommand(TableIdentifier("tab1", None), Some("propKey1")) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala new file mode 100644 index 0000000000..4c3f450522 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + protected override def beforeAll(): Unit = { + super.beforeAll() + sql( + """ + |CREATE EXTERNAL TABLE parquet_tab1 (c1 INT, c2 STRING) + |USING org.apache.spark.sql.parquet.DefaultSource + """.stripMargin) + + sql( + """ + |CREATE EXTERNAL TABLE parquet_tab2 (c1 INT, c2 STRING) + |STORED AS PARQUET + |TBLPROPERTIES('prop1Key'="prop1Val", '`prop2Key`'="prop2Val") + """.stripMargin) + } + + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS parquet_tab1") + sql("DROP TABLE IF EXISTS parquet_tab2") + } finally { + super.afterAll() + } + } + + test("show tables") { + withTable("show1a", "show2b") { + sql("CREATE TABLE show1a(c1 int)") + sql("CREATE TABLE show2b(c2 int)") + checkAnswer( + sql("SHOW TABLES IN default 'show1*'"), + Row("show1a", false) :: Nil) + checkAnswer( + sql("SHOW TABLES IN default 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + checkAnswer( + sql("SHOW TABLES 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + assert( + sql("SHOW TABLES").count() >= 2) + assert( + sql("SHOW TABLES IN default").count() >= 2) + } + } + + test("show tblproperties of data source tables - basic") { + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1") + .filter(s"key = 'spark.sql.sources.provider'"), + Row("spark.sql.sources.provider", "org.apache.spark.sql.parquet.DefaultSource") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1(spark.sql.sources.provider)"), + Row("org.apache.spark.sql.parquet.DefaultSource") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1") + .filter(s"key = 'spark.sql.sources.schema.numParts'"), + Row("spark.sql.sources.schema.numParts", "1") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1('spark.sql.sources.schema.numParts')"), + Row("1")) + } + + test("show tblproperties for datasource table - errors") { + val message1 = intercept[AnalysisException] { + sql("SHOW TBLPROPERTIES badtable") + }.getMessage + assert(message1.contains("Table badtable not found in database default")) + + // When key is not found, a row containing the error is returned. + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1('invalid.prop.key')"), + Row("Table default.parquet_tab1 does not have property: invalid.prop.key") :: Nil + ) + } + + test("show tblproperties for hive table") { + checkAnswer(sql("SHOW TBLPROPERTIES parquet_tab2('prop1Key')"), Row("prop1Val")) + checkAnswer(sql("SHOW TBLPROPERTIES parquet_tab2('`prop2Key`')"), Row("prop2Val")) + } + + test("show tblproperties for spark temporary table - empty row") { + withTempTable("parquet_temp") { + sql( + """ + |CREATE TEMPORARY TABLE parquet_temp (c1 INT, c2 STRING) + |USING org.apache.spark.sql.parquet.DefaultSource + """.stripMargin) + + // An empty sequence of row is returned for session temporary table. 
+ checkAnswer(sql("SHOW TBLPROPERTIES parquet_temp"), Nil) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c203518fdd..6199253d34 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1811,26 +1811,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } - - test("show tables") { - withTable("show1a", "show2b") { - sql("CREATE TABLE show1a(c1 int)") - sql("CREATE TABLE show2b(c2 int)") - checkAnswer( - sql("SHOW TABLES IN default 'show1*'"), - Row("show1a", false) :: Nil) - checkAnswer( - sql("SHOW TABLES IN default 'show1*|show2*'"), - Row("show1a", false) :: - Row("show2b", false) :: Nil) - checkAnswer( - sql("SHOW TABLES 'show1*|show2*'"), - Row("show1a", false) :: - Row("show2b", false) :: Nil) - assert( - sql("SHOW TABLES").count() >= 2) - assert( - sql("SHOW TABLES IN default").count() >= 2) - } - } } -- cgit v1.2.3 From 78071736799b6c86b5c01b27395f4ab87075342b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 5 Apr 2016 11:19:46 +0200 Subject: [SPARK-14349][SQL] Issue Error Messages for Unsupported Operators/DML/DDL in SQL Context. #### What changes were proposed in this pull request? Currently, the weird error messages are issued if we use Hive Context-only operations in SQL Context. For example, - When calling `Drop Table` in SQL Context, we got the following message: ``` Expected exception org.apache.spark.sql.catalyst.parser.ParseException to be thrown, but java.lang.ClassCastException was thrown. ``` - When calling `Script Transform` in SQL Context, we got the message: ``` assertion failed: No plan for ScriptTransformation [key#9,value#10], cat, [tKey#155,tValue#156], null +- LogicalRDD [key#9,value#10], MapPartitionsRDD[3] at beforeAll at BeforeAndAfterAll.scala:187 ``` Updates: Based on the investigation from hvanhovell , the root cause is `visitChildren`, which is the default implementation. It always returns the result of the last defined context child. After merging the code changes from hvanhovell , it works! Thank you hvanhovell ! #### How was this patch tested? A few test cases are added. Not sure if the same issue exist for the other operators/DDL/DML. hvanhovell Author: gatorsmile Author: xiaoli Author: Herman van Hovell Author: Xiao Li Closes #12134 from gatorsmile/hiveParserCommand. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 14 +-- .../spark/sql/catalyst/parser/AstBuilder.scala | 118 +++++++++++---------- .../sql/catalyst/parser/PlanParserSuite.scala | 10 -- .../sql/execution/command/DDLCommandSuite.scala | 23 ++++ .../spark/sql/hive/execution/HiveSqlParser.scala | 16 ++- .../org/apache/spark/sql/hive/HiveQlSuite.scala | 31 +++++- 6 files changed, 133 insertions(+), 79 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 27b01e0bed..96c170be3d 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -468,15 +468,15 @@ booleanExpression // https://github.com/antlr/antlr4/issues/780 // https://github.com/antlr/antlr4/issues/781 predicated - : valueExpression predicate[$valueExpression.ctx]? 
+ : valueExpression predicate? ; -predicate[ParserRuleContext value] - : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between - | NOT? IN '(' expression (',' expression)* ')' #inList - | NOT? IN '(' query ')' #inSubquery - | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like - | IS NOT? NULL #nullPredicate +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? kind=IN '(' query ')' + | NOT? kind=(RLIKE | LIKE) pattern=valueExpression + | IS NOT? kind=NULL ; valueExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 61ea3e4010..14c90918e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.antlr.v4.runtime.{ParserRuleContext, Token} -import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} @@ -46,6 +46,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { ctx.accept(this).asInstanceOf[T] } + /** + * Override the default behavior for all visit methods. This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. + */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { visit(ctx.statement).asInstanceOf[LogicalPlan] } @@ -351,7 +364,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { string(script), attributes, withFilter, - withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + withScriptIOSchema( + ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) case SqlBaseParser.SELECT => // Regular select @@ -398,11 +412,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a (Hive based) [[ScriptInputOutputSchema]]. */ protected def withScriptIOSchema( + ctx: QuerySpecificationContext, inRowFormat: RowFormatContext, recordWriter: Token, outRowFormat: RowFormatContext, recordReader: Token, - schemaLess: Boolean): ScriptInputOutputSchema = null + schemaLess: Boolean): ScriptInputOutputSchema = { + throw new ParseException("Script Transform is not supported", ctx) + } /** * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma @@ -778,17 +795,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { trees.asScala.map(expression) } - /** - * Invert a boolean expression if it has a valid NOT clause. - */ - private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { - if (not != null) { - Not(expression) - } else { - expression - } - } - /** * Create a star (i.e. all) expression; this selects all elements (in the specified object). * Both un-targeted (global) and targeted aliases are supported. 
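// Toy, self-contained illustration (plain case classes, not the real ANTLR RuleNode) of why the
// visitChildren override added earlier in this diff only forwards the visit when a context has
// exactly one child: with several children there is no generic way to combine their results, and
// ANTLR's default of returning the last child's result is what produced the confusing failures
// described in this commit.
object VisitChildrenSketch {
  sealed trait Node { def children: Seq[Node] }
  case class Leaf(plan: String) extends Node { val children: Seq[Node] = Nil }
  case class Branch(children: Node*) extends Node

  def visit(node: Node): Option[String] = node match {
    case Leaf(plan) => Some(plan)
    case other if other.children.size == 1 => visit(other.children.head) // single child: pass through
    case _ => None // ambiguous: refuse instead of silently picking the last child
  }

  def main(args: Array[String]): Unit = {
    assert(visit(Branch(Leaf("LogicalPlan"))) == Some("LogicalPlan"))
    assert(visit(Branch(Leaf("a"), Leaf("b"))).isEmpty)
  }
}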
@@ -909,57 +915,55 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two - * other expressions. The inverse can also be created. - */ - override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - val between = And( - GreaterThanOrEqual(value, expression(ctx.lower)), - LessThanOrEqual(value, expression(ctx.upper))) - invertIfNotDefined(between, ctx.NOT) - } - - /** - * Create an IN expression. This tests if the value of the left hand side expression is - * contained by the sequence of expressions on the right hand side. + * Create a predicated expression. A predicated expression is a normal expression with a + * predicate attached to it, for example: + * {{{ + * a + 1 IS NULL + * }}} */ - override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { - val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) - invertIfNotDefined(in, ctx.NOT) + override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + if (ctx.predicate != null) { + withPredicate(e, ctx.predicate) + } else { + e + } } /** - * Create an IN expression, where the the right hand side is a query. This is unsupported. + * Add a predicate to the given expression. Supported expressions are: + * - (NOT) BETWEEN + * - (NOT) IN + * - (NOT) LIKE + * - (NOT) RLIKE + * - IS (NOT) NULL. */ - override def visitInSubquery(ctx: InSubqueryContext): Expression = { - throw new ParseException("IN with a Sub-query is currently not supported.", ctx) - } + private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { + // Invert a predicate if it has a valid NOT clause. + def invertIfNotDefined(e: Expression): Expression = ctx.NOT match { + case null => e + case not => Not(e) + } - /** - * Create a (R)LIKE/REGEXP expression. - */ - override def visitLike(ctx: LikeContext): Expression = { - val left = expression(ctx.value) - val right = expression(ctx.pattern) - val like = ctx.like.getType match { + // Create the predicate. + ctx.kind.getType match { + case SqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case SqlBaseParser.IN if ctx.query != null => + throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + case SqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => - Like(left, right) + invertIfNotDefined(Like(e, expression(ctx.pattern))) case SqlBaseParser.RLIKE => - RLike(left, right) - } - invertIfNotDefined(like, ctx.NOT) - } - - /** - * Create an IS (NOT) NULL expression. 
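// Rough sketch (simplified expression classes, not Catalyst's) of the single withPredicate
// dispatch introduced just above: the predicate kind is inspected once, BETWEEN is lowered to a
// conjunction of comparisons, and a leading NOT is handled uniformly by wrapping the result.
object WithPredicateSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Int) extends Expr
  case class GreaterThanOrEqual(left: Expr, right: Expr) extends Expr
  case class LessThanOrEqual(left: Expr, right: Expr) extends Expr
  case class And(left: Expr, right: Expr) extends Expr
  case class Not(child: Expr) extends Expr

  def between(e: Expr, lower: Expr, upper: Expr, negated: Boolean): Expr = {
    val b = And(GreaterThanOrEqual(e, lower), LessThanOrEqual(e, upper))
    if (negated) Not(b) else b
  }

  def main(args: Array[String]): Unit = {
    // a BETWEEN 1 AND 10      ==>  a >= 1 AND a <= 10
    println(between(Attr("a"), Lit(1), Lit(10), negated = false))
    // a NOT BETWEEN 1 AND 10  ==>  NOT(a >= 1 AND a <= 10)
    println(between(Attr("a"), Lit(1), Lit(10), negated = true))
  }
}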
- */ - override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - if (ctx.NOT != null) { - IsNotNull(value) - } else { - IsNull(value) + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case SqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case SqlBaseParser.NULL => + IsNull(e) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 23f05ce846..9e1660df06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -122,16 +122,6 @@ class PlanParserSuite extends PlanTest { table("a").union(table("b")).as("c").select(star())) } - test("transform query spec") { - val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) - assertEqual("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - assertEqual("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) - assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - test("multi select query") { assertEqual( "from a select * select * where s < 10", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 8b2a5979e2..47e295a7e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.BucketSpec @@ -781,4 +782,26 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } + + test("commands only available in HiveContext") { + intercept[ParseException] { + parser.parsePlan("DROP TABLE D1.T1") + } + intercept[ParseException] { + parser.parsePlan("CREATE VIEW testView AS SELECT id FROM tab") + } + intercept[ParseException] { + parser.parsePlan("ALTER VIEW testView AS SELECT id FROM tab") + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE EXTERNAL TABLE parquet_tab2(c1 INT, c2 STRING) + |TBLPROPERTIES('prop1Key '= "prop1Val", ' `prop2Key` '= "prop2Val") + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData") + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 12e4f49756..55e69f99a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -133,6 +133,18 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } } + /** + * Create a [[CatalogStorageFormat]]. 
This is part of the [[CreateTableAsSelect]] command. + */ + override def visitCreateFileFormat( + ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + if (ctx.storageHandler == null) { + typedVisit[CatalogStorageFormat](ctx.fileFormat) + } else { + visitStorageHandler(ctx.storageHandler) + } + } + /** * Create a [[CreateTableAsSelect]] command. */ @@ -282,6 +294,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { * Create a [[HiveScriptIOSchema]]. */ override protected def withScriptIOSchema( + ctx: QuerySpecificationContext, inRowFormat: RowFormatContext, recordWriter: Token, outRowFormat: RowFormatContext, @@ -391,7 +404,8 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { /** * Storage Handlers are currently not supported in the statements we support (CTAS). */ - override def visitStorageHandler(ctx: StorageHandlerContext): AnyRef = withOrigin(ctx) { + override def visitStorageHandler( + ctx: StorageHandlerContext): CatalogStorageFormat = withOrigin(ctx) { throw new ParseException("Storage Handlers are currently unsupported.", ctx) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 75108c6d47..a8a0d6b8de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants -import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.plans.logical.Generate +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} import org.apache.spark.sql.hive.execution.HiveSqlParser -class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { +class HiveQlSuite extends PlanTest { val parser = HiveSqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { @@ -201,6 +204,26 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) } + test("transform query spec") { + val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + + val p = ScriptTransformation( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + "func", Seq.empty, plans.table("e"), null) + + comparePlans(plan1, + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + comparePlans(plan2, + p.copy(output = Seq('c.string, 'd.string))) + comparePlans(plan3, + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + test("use backticks in 
output of Script Transform") { val plan = parser.parsePlan( """SELECT `t`.`thing1` -- cgit v1.2.3 From d35690158810465809679ef39548e1400b38d448 Mon Sep 17 00:00:00 2001 From: Shally Sangal Date: Tue, 5 Apr 2016 10:41:59 -0700 Subject: [SPARK-14284][ML] KMeansSummary deprecating size; adding clusterSizes ## What changes were proposed in this pull request? KMeansSummary class : deprecated size and added clusterSizes Author: Shally Sangal Closes #12084 from shallys/master. --- mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 3 ++- mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 38428826a8..a8beef8b12 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -299,7 +299,8 @@ class KMeansSummary private[clustering] ( * Size of each cluster. */ @Since("2.0.0") - lazy val size: Array[Int] = cluster.rdd.map { + lazy val clusterSizes: Array[Int] = cluster.rdd.map { case Row(clusterIdx: Int) => (clusterIdx, 1) }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) + } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index d3a0df4063..ed735a4ea3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -37,7 +37,7 @@ private[r] class KMeansWrapper private ( lazy val k: Int = kMeansModel.getK - lazy val size: Array[Int] = kMeansModel.summary.size + lazy val size: Array[Int] = kMeansModel.summary.clusterSizes lazy val cluster: DataFrame = kMeansModel.summary.cluster -- cgit v1.2.3 From e4bd50412043c1ed2816406ba8d2af4f775ee3cf Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 5 Apr 2016 10:51:23 -0700 Subject: [SPARK-14397][WEBUI] and tags are nested in LogPage ## What changes were proposed in this pull request? In `LogPage`, the content to be rendered is defined as follows. ``` val content = {linkToMaster}
          <div>
            <div style="float:left; margin-right:10px">{backButton}</div>
            <div style="float:left;">{range}</div>
            <div style="float:right; margin-left:10px">{nextButton}</div>
          </div>
          <br />
          <div style="height:500px; overflow:auto; padding:5px;">
            <pre>{logText}</pre>
          </div>
        </body>
      </html>
    UIUtils.basicSparkPage(content, logType + " log page for " + pageName) ``` As you can see, and tags will be rendered. On the other hand, `UIUtils.basicSparkPage` will render those tags so those tags will be nested. ``` def basicSparkPage( content: => Seq[Node], title: String, useDataTables: Boolean = false): Seq[Node] = { {commonHeaderNodes} {if (useDataTables) dataTablesHeaderNodes else Seq.empty} {title} } ``` These are the screen shots before this patch is applied. ![before1](https://cloud.githubusercontent.com/assets/4736016/14273236/03cbed8a-fb44-11e5-8786-bc1bfa4d3f8c.png) ![before2](https://cloud.githubusercontent.com/assets/4736016/14273237/03d1741c-fb44-11e5-9dee-ea93022033a6.png) And these are the ones after this patch is applied. ![after1](https://cloud.githubusercontent.com/assets/4736016/14273248/1b6a7d8a-fb44-11e5-8a3b-69964f3434f6.png) ![after2](https://cloud.githubusercontent.com/assets/4736016/14273249/1b6b9c38-fb44-11e5-9d6f-281d64c842e4.png) The appearance is not changed but the html source code is changed. ## How was this patch tested? Manually run some jobs on my standalone-cluster and check the WebUI. Author: Kousuke Saruta Closes #12170 from sarutak/SPARK-14397. --- .../apache/spark/deploy/worker/ui/LogPage.scala | 26 ++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 6500cab73b..e75c0cec4a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -107,20 +107,18 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with } val content = - - - {linkToMaster} -
-            <div style="float:left; margin-right:10px">{backButton}</div>
-            <div style="float:left;">{range}</div>
-            <div style="float:right; margin-left:10px">{nextButton}</div>
-          </div>
-          <br />
-          <div style="height:500px; overflow:auto; padding:5px;">
-            <pre>{logText}</pre>
-          </div>
-        </body>
-      </html>
+      <div>
+        {linkToMaster}
+        <div>
+          <div style="float:left; margin-right:10px">{backButton}</div>
+          <div style="float:left;">{range}</div>
+          <div style="float:right; margin-left:10px">{nextButton}</div>
+        </div>
+        <br />
+        <div style="height:500px; overflow:auto; padding:5px;">
+          <pre>{logText}</pre>
+        </div>
+      </div>
    UIUtils.basicSparkPage(content, logType + " log page for " + pageName) } -- cgit v1.2.3 From f77f11c67125fdac2e6849a4d45d9286fc872ed9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Apr 2016 10:53:54 -0700 Subject: [SPARK-14345][SQL] Decouple deserializer expression resolution from ObjectOperator ## What changes were proposed in this pull request? This PR decouples deserializer expression resolution from `ObjectOperator`, so that we can use deserializer expression in normal operators. This is needed by #12061 and #12067 , I abstracted the logic out and put them in this PR to reduce code change in the future. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #12131 from cloud-fan/separate. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 183 +++++++++++---------- .../spark/sql/catalyst/analysis/unresolved.scala | 22 +++ .../sql/catalyst/encoders/ExpressionEncoder.scala | 8 +- .../spark/sql/catalyst/expressions/objects.scala | 14 +- .../spark/sql/catalyst/plans/logical/object.scala | 52 ++---- 5 files changed, 153 insertions(+), 126 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a6e317ebf0..3e0a6d29b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import java.lang.reflect.Modifier - import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer @@ -87,9 +85,11 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: + ResolveDeserializer :: + ResolveNewInstance :: + ResolveUpCast :: ResolveGroupingAnalytics :: ResolvePivot :: - ResolveUpCast :: ResolveOrdinalInOrderByAndGroupBy :: ResolveSortReferences :: ResolveGenerate :: @@ -499,18 +499,9 @@ class Analyzer( Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } - // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator - // should be resolved by their corresponding attributes instead of children's output. - case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) => - val deserializerToAttributes = o.deserializers.map { - case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes - }.toMap - - o.transformExpressions { - case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes => - resolveDeserializer(expr, attributes) - }.getOrElse(expr) - } + // Skips plan which contains deserializer expressions, as they should be resolved by another + // rule: ResolveDeserializer. 
+ case plan if containsDeserializer(plan.expressions) => plan case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") @@ -526,38 +517,6 @@ class Analyzer( } } - private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = { - exprs.exists { expr => - !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined - } - } - - def resolveDeserializer( - deserializer: Expression, - attributes: Seq[Attribute]): Expression = { - val unbound = deserializer transform { - case b: BoundReference => attributes(b.ordinal) - } - - resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { - case n: NewInstance - // If this is an inner class of another class, register the outer object in `OuterScopes`. - // Note that static inner classes (e.g., inner classes within Scala objects) don't need - // outer pointer registration. - if n.outerPointer.isEmpty && - n.cls.isMemberClass && - !Modifier.isStatic(n.cls.getModifiers) => - val outer = OuterScopes.getOuterScope(n.cls) - if (outer == null) { - throw new AnalysisException( - s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + - "access to the scope that this class was defined in.\n" + - "Try moving this class out of its parent class.") - } - n.copy(outerPointer = Some(outer)) - } - } - def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated) @@ -623,6 +582,10 @@ class Analyzer( } } + private def containsDeserializer(exprs: Seq[Expression]): Boolean = { + exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined) + } + protected[sql] def resolveExpression( expr: Expression, plan: LogicalPlan, @@ -1475,7 +1438,94 @@ class Analyzer( Project(projectList, Join(left, right, joinType, newCondition)) } + /** + * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved + * to the given input attributes. + */ + object ResolveDeserializer extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + case p => p transformExpressions { + case UnresolvedDeserializer(deserializer, inputAttributes) => + val inputs = if (inputAttributes.isEmpty) { + p.children.flatMap(_.output) + } else { + inputAttributes + } + val unbound = deserializer transform { + case b: BoundReference => inputs(b.ordinal) + } + resolveExpression(unbound, LocalRelation(inputs), throws = true) + } + } + } + + /** + * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being + * constructed is an inner class. + */ + object ResolveNewInstance extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + + case p => p transformExpressions { + case n: NewInstance if n.childrenResolved && !n.resolved => + val outer = OuterScopes.getOuterScope(n.cls) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + + "access to the scope that this class was defined in.\n" + + "Try moving this class out of its parent class.") + } + n.copy(outerPointer = Some(outer)) + } + } + } + + /** + * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate. 
+ */ + object ResolveUpCast extends Rule[LogicalPlan] { + private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { + throw new AnalysisException(s"Cannot up cast ${from.sql} from " + + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + toPrecedence > 0 && fromPrecedence > toPrecedence + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + + case p => p transformExpressions { + case u @ UpCast(child, _, _) if !child.resolved => u + + case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => + fail(child, to, walkedTypePath) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => + fail(child, to, walkedTypePath) + case (from, to) if illegalNumericPrecedence(from, to) => + fail(child, to, walkedTypePath) + case (TimestampType, DateType) => + fail(child, DateType, walkedTypePath) + case (StringType, to: NumericType) => + fail(child, to, walkedTypePath) + case _ => Cast(child, dataType.asNullable) + } + } + } + } } /** @@ -1559,45 +1609,6 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** - * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. - */ -object ResolveUpCast extends Rule[LogicalPlan] { - private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { - throw new AnalysisException(s"Cannot up cast ${from.sql} from " + - s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + - "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + - "You can either add an explicit cast to the input data or choose a higher precision " + - "type of the field in the target object") - } - - private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { - val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) - val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) - toPrecedence > 0 && fromPrecedence > toPrecedence - } - - def apply(plan: LogicalPlan): LogicalPlan = { - plan transformAllExpressions { - case u @ UpCast(child, _, _) if !child.resolved => u - - case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { - case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => - fail(child, to, walkedTypePath) - case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => - fail(child, to, walkedTypePath) - case (from, to) if illegalNumericPrecedence(from, to) => - fail(child, to, walkedTypePath) - case (TimestampType, DateType) => - fail(child, DateType, walkedTypePath) - case (StringType, to: NumericType) => - fail(child, to, walkedTypePath) - case _ => Cast(child, dataType.asNullable) - } - } - } -} - /** * Maps a time column to multiple time windows using the Expand operator. 
Since it's non-trivial to * figure out how many windows a time column can map to, we over-estimate the number of windows and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index e73d367a73..fbbf6302e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -307,3 +307,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) override lazy val resolved = false } + +/** + * Holds the deserializer expression and the attributes that are available during the resolution + * for it. Deserializer expression is a special kind of expression that is not always resolved by + * children output, but by given attributes, e.g. the `keyDeserializer` in `MapGroups` should be + * resolved by `groupingAttributes` instead of children output. + * + * @param deserializer The unresolved deserializer expression + * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty + * if we want to resolve deserializer by children output. + */ +case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute]) + extends UnaryExpression with Unevaluable with NonSQLExpression { + // The input attributes used to resolve deserializer expression must be all resolved. + require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.") + + override def child: Expression = deserializer + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 1c712fde26..56d29cfbe1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts @@ -317,11 +317,11 @@ case class ExpressionEncoder[T]( def resolve( schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema) - // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check // analysis, go through optimizer, etc. 
- val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema)) + val plan = Project( + Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil, + LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) SimpleAnalyzer.checkAnalysis(analyzedPlan) copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 07b67a0240..eebd43dae9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.lang.reflect.Modifier + import scala.annotation.tailrec import scala.language.existentials import scala.reflect.ClassTag @@ -112,7 +114,7 @@ case class Invoke( arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { override def nullable: Boolean = true - override def children: Seq[Expression] = arguments.+:(targetObject) + override def children: Seq[Expression] = targetObject +: arguments override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @@ -214,6 +216,16 @@ case class NewInstance( override def children: Seq[Expression] = arguments + override lazy val resolved: Boolean = { + // If the class to construct is an inner class, we need to get its outer pointer, or this + // expression should be regarded as unresolved. + // Note that static inner classes (e.g., inner classes within Scala objects) don't need + // outer pointer registration. + val needOuterPointer = + outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers) + childrenResolved && !needOuterPointer + } + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 058fb6bff1..58313c7b72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{ObjectType, StructType} @@ -32,13 +33,6 @@ trait ObjectOperator extends LogicalPlan { override def output: Seq[Attribute] = serializer.map(_.toAttribute) - /** - * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects. - * It must also provide the attributes that are available during the resolution of each - * deserializer. - */ - def deserializers: Seq[(Expression, Seq[Attribute])] - /** * The object type that is produced by the user defined function. Note that the return type here * is the same whether or not the operator is output serialized data. 
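// Self-contained sketch (toy expression types, not Catalyst's) of the idea behind the
// UnresolvedDeserializer introduced in this commit: a deserializer is written against positional
// BoundReferences, and resolution later rebinds those positions to whichever attribute list the
// operator chooses, e.g. grouping attributes for a key deserializer and data attributes for the
// value one.
object DeserializerResolutionSketch {
  sealed trait Expr
  case class BoundRef(ordinal: Int) extends Expr
  case class Attr(name: String) extends Expr
  case class GetField(child: Expr, field: String) extends Expr

  def resolve(deserializer: Expr, input: Seq[Attr]): Expr = deserializer match {
    case BoundRef(i)    => input(i)
    case GetField(c, f) => GetField(resolve(c, input), f)
    case other          => other
  }

  def main(args: Array[String]): Unit = {
    val keyDeserializer = GetField(BoundRef(0), "id")
    // The same unresolved expression can be bound against different inputs, which is what lets
    // MapGroups resolve its key against groupingAttributes instead of the child's full output.
    println(resolve(keyDeserializer, Seq(Attr("groupingKey"))))
    println(resolve(keyDeserializer, Seq(Attr("childOutputColumn"))))
  }
}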
@@ -71,7 +65,7 @@ object MapPartitions { child: LogicalPlan): MapPartitions = { MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - encoderFor[T].deserializer, + UnresolvedDeserializer(encoderFor[T].deserializer, Nil), encoderFor[U].namedExpressions, child) } @@ -87,9 +81,7 @@ case class MapPartitions( func: Iterator[Any] => Iterator[Any], deserializer: Expression, serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator { - override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) -} + child: LogicalPlan) extends UnaryNode with ObjectOperator /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { @@ -98,7 +90,7 @@ object AppendColumns { child: LogicalPlan): AppendColumns = { new AppendColumns( func.asInstanceOf[Any => Any], - encoderFor[T].deserializer, + UnresolvedDeserializer(encoderFor[T].deserializer, Nil), encoderFor[U].namedExpressions, child) } @@ -120,8 +112,6 @@ case class AppendColumns( override def output: Seq[Attribute] = child.output ++ newColumns def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) - - override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } /** Factory for constructing new `MapGroups` nodes. */ @@ -133,8 +123,8 @@ object MapGroups { child: LogicalPlan): MapGroups = { new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], - encoderFor[K].deserializer, - encoderFor[T].deserializer, + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), encoderFor[U].namedExpressions, groupingAttributes, dataAttributes, @@ -158,11 +148,7 @@ case class MapGroups( serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], - child: LogicalPlan) extends UnaryNode with ObjectOperator { - - override def deserializers: Seq[(Expression, Seq[Attribute])] = - Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes) -} + child: LogicalPlan) extends UnaryNode with ObjectOperator /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { @@ -170,22 +156,24 @@ object CoGroup { func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], - leftData: Seq[Attribute], - rightData: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], left: LogicalPlan, right: LogicalPlan): CoGroup = { require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], - encoderFor[Key].deserializer, - encoderFor[Left].deserializer, - encoderFor[Right].deserializer, + // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to + // resolve the `keyDeserializer` based on either of them, here we pick the left one. 
+ UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup), + UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr), + UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr), encoderFor[Result].namedExpressions, leftGroup, rightGroup, - leftData, - rightData, + leftAttr, + rightAttr, left, right) } @@ -206,10 +194,4 @@ case class CoGroup( leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan) extends BinaryNode with ObjectOperator { - - override def deserializers: Seq[(Expression, Seq[Attribute])] = - // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve - // the `keyDeserializer` based on either of them, here we pick the left one. - Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr) -} + right: LogicalPlan) extends BinaryNode with ObjectOperator -- cgit v1.2.3 From 463bac001171622538fc93d2e31d1a617ab562e6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 5 Apr 2016 11:12:05 -0700 Subject: [SPARK-14257][SQL] Allow multiple continuous queries to be started from the same DataFrame ## What changes were proposed in this pull request? Make StreamingRelation store the closure to create the source in StreamExecution so that we can start multiple continuous queries from the same DataFrame. ## How was this patch tested? `test("DataFrame reuse")` Author: Shixiong Zhu Closes #12049 from zsxwing/df-reuse. --- .../apache/spark/sql/ContinuousQueryManager.scala | 12 ++++- .../org/apache/spark/sql/DataFrameReader.scala | 2 +- .../main/scala/org/apache/spark/sql/Dataset.scala | 6 ++- .../sql/execution/streaming/StreamExecution.scala | 8 +-- .../execution/streaming/StreamingRelation.scala | 27 ++++++++-- .../spark/sql/execution/streaming/memory.scala | 6 +-- .../scala/org/apache/spark/sql/StreamTest.scala | 5 +- .../streaming/ContinuousQueryManagerSuite.scala | 6 +-- .../sql/streaming/FileStreamSourceSuite.scala | 10 ++-- .../apache/spark/sql/streaming/StreamSuite.scala | 62 +++++++++++++++++++++- 10 files changed, 118 insertions(+), 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 2306df09b8..d7f71bd4b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, Sink, StreamExecution} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.util.ContinuousQueryListener @@ -178,11 +178,19 @@ class ContinuousQueryManager(sqlContext: SQLContext) { throw new IllegalArgumentException( s"Cannot start query with name $name as a query with that name is already active") } + val logicalPlan = df.logicalPlan.transform { + case StreamingRelation(dataSource, _, output) => + // Materialize source to avoid creating it in every batch + val source = dataSource.createSource() + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. 
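// Toy model (plain classes, not the real streaming API) of the change this commit makes: the
// DataFrame's logical plan keeps a factory (DataSource) rather than a materialized Source, and the
// factory is only turned into a Source at the point a query is started, so one DataFrame can back
// several independent queries.
object StreamingRelationSketch {
  class Source(val id: Int)
  class DataSource {
    private var counter = 0
    def createSource(): Source = { counter += 1; new Source(counter) }
  }

  case class StreamingRelation(dataSource: DataSource)         // what the DataFrame's plan holds
  case class StreamingExecutionRelation(source: Source)        // what a running query holds

  def startQuery(plan: StreamingRelation): StreamingExecutionRelation =
    StreamingExecutionRelation(plan.dataSource.createSource()) // materialize one Source per query

  def main(args: Array[String]): Unit = {
    val df = StreamingRelation(new DataSource)
    val q1 = startQuery(df)
    val q2 = startQuery(df)
    assert(q1.source.id != q2.source.id) // two started queries, two independent sources
  }
}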
+ StreamingExecutionRelation(source, output) + } val query = new StreamExecution( sqlContext, name, checkpointLocation, - df.logicalPlan, + logicalPlan, sink, trigger) query.start() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a5a6e01e99..15f2344df6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { userSpecifiedSchema = userSpecifiedSchema, className = source, options = extraOptions.toMap) - Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource())) + Dataset.ofRows(sqlContext, StreamingRelation(dataSource)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index db2134b020..f472a5068e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.execution.python.EvaluatePython -import org.apache.spark.sql.execution.streaming.StreamingRelation +import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, StreamingRelation} import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -462,7 +462,9 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental - def isStreaming: Boolean = logicalPlan.find(_.isInstanceOf[StreamingRelation]).isDefined + def isStreaming: Boolean = logicalPlan.find { n => + n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation] + }.isDefined /** * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 64f80699ce..3e4acb752a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -43,9 +43,9 @@ import org.apache.spark.util.UninterruptibleThread * and the results are committed transactionally to the given [[Sink]]. */ class StreamExecution( - val sqlContext: SQLContext, + override val sqlContext: SQLContext, override val name: String, - val checkpointRoot: String, + checkpointRoot: String, private[sql] val logicalPlan: LogicalPlan, val sink: Sink, val trigger: Trigger) extends ContinuousQuery with Logging { @@ -72,7 +72,7 @@ class StreamExecution( /** All stream sources present the query plan. */ private val sources = - logicalPlan.collect { case s: StreamingRelation => s.source } + logicalPlan.collect { case s: StreamingExecutionRelation => s.source } /** A list of unique sources in the query plan. */ private val uniqueSources = sources.distinct @@ -295,7 +295,7 @@ class StreamExecution( var replacements = new ArrayBuffer[(Attribute, Attribute)] // Replace sources in the logical plan with data that has arrived since the last batch. 
val withNewSources = logicalPlan transform { - case StreamingRelation(source, output) => + case StreamingExecutionRelation(source, output) => newData.get(source).map { data => val newPlan = data.logicalPlan assert(output.size == newPlan.output.size, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index e35c444348..f951dea735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -19,16 +19,37 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LeafNode +import org.apache.spark.sql.execution.datasources.DataSource object StreamingRelation { - def apply(source: Source): StreamingRelation = - StreamingRelation(source, source.schema.toAttributes) + def apply(dataSource: DataSource): StreamingRelation = { + val source = dataSource.createSource() + StreamingRelation(dataSource, source.toString, source.schema.toAttributes) + } +} + +/** + * Used to link a streaming [[DataSource]] into a + * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating + * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]]. + * It should be used to create [[Source]] and converted to [[StreamingExecutionRelation]] when + * passing to [StreamExecution]] to run a query. + */ +case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) + extends LeafNode { + override def toString: String = sourceName } /** * Used to link a streaming [[Source]] of data into a * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. 
*/ -case class StreamingRelation(source: Source, output: Seq[Attribute]) extends LeafNode { +case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode { override def toString: String = source.toString } + +object StreamingExecutionRelation { + def apply(source: Source): StreamingExecutionRelation = { + StreamingExecutionRelation(source, source.schema.toAttributes) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 7d97f81b0f..b652530d7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -22,11 +22,9 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal -import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} -import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder} -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.types.StructType object MemoryStream { @@ -45,7 +43,7 @@ object MemoryStream { case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) extends Source with Logging { protected val encoder = encoderFor[A] - protected val logicalPlan = StreamingRelation(this) + protected val logicalPlan = StreamingExecutionRelation(this) protected val output = logicalPlan.output protected val batches = new ArrayBuffer[Dataset[A]] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 3444e56e9e..6ccc99fe17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -36,6 +36,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.util.Utils @@ -66,9 +67,9 @@ import org.apache.spark.util.Utils trait StreamTest extends QueryTest with Timeouts { implicit class RichSource(s: Source) { - def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingRelation(s)) + def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingExecutionRelation(s)) - def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingRelation(s)) + def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, StreamingExecutionRelation(s)) } /** How long to wait for an active stream to catch up when checking a result. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index 29bd3e018e..33787de9da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkException import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest} -import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, StreamExecution, StreamingRelation} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -294,8 +294,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with if (withError) { logDebug(s"Terminating query ${queryToStop.name} with error") queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect { - case StreamingRelation(memoryStream, _) => - memoryStream.asInstanceOf[MemoryStream[Int]].addData(0) + case StreamingExecutionRelation(source, _) => + source.asInstanceOf[MemoryStream[Int]].addData(0) } } else { logDebug(s"Stopping query ${queryToStop.name}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 054f5c9fa2..09daa7f81a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -71,8 +71,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { } reader.stream(path) .queryExecution.analyzed - .collect { case StreamingRelation(s: FileStreamSource, _) => s } - .head + .collect { case StreamingRelation(dataSource, _, _) => + dataSource.createSource().asInstanceOf[FileStreamSource] + }.head } val valueSchema = new StructType().add("value", StringType) @@ -96,8 +97,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { reader.stream() } df.queryExecution.analyzed - .collect { case StreamingRelation(s: FileStreamSource, _) => s } - .head + .collect { case StreamingRelation(dataSource, _, _) => + dataSource.createSource().asInstanceOf[FileStreamSource] + }.head .schema } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index fbb1792596..e4ea555526 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,9 +17,13 @@ package org.apache.spark.sql.streaming -import org.apache.spark.sql.{Row, StreamTest} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest} import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class StreamSuite extends StreamTest with SharedSQLContext { @@ -81,4 +85,60 @@ class StreamSuite extends StreamTest with SharedSQLContext { AddData(inputData, 1, 2, 3, 4), CheckAnswer(2, 4)) } + + test("DataFrame reuse") { + def assertDF(df: DataFrame) { + 
withTempDir { outputDir => + withTempDir { checkpointDir => + val query = df.write.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .startStream(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDataset[Long](outputDf, (0L to 10L).toArray: _*) + } finally { + query.stop() + } + } + } + } + + val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream() + assertDF(df) + assertDF(df) + } +} + +/** + * A fake StreamSourceProvider thats creates a fake Source that cannot be reused. + */ +class FakeDefaultSource extends StreamSourceProvider { + + override def createSource( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + // Create a fake Source that emits 0 to 10. + new Source { + private var offset = -1L + + override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + + override def getOffset: Option[Offset] = { + if (offset >= 10) { + None + } else { + offset += 1 + Some(LongOffset(offset)) + } + } + + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 + sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") + } + } + } } -- cgit v1.2.3 From bc36df127d3b9f56b4edaeb5eca7697d4aef761a Mon Sep 17 00:00:00 2001 From: Devaraj K Date: Tue, 5 Apr 2016 14:12:00 -0500 Subject: [SPARK-13063][YARN] Make the SPARK YARN STAGING DIR as configurable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Made the SPARK YARN STAGING DIR as configurable with the configuration as 'spark.yarn.staging-dir'. ## How was this patch tested? I have verified it manually by running applications on yarn, If the 'spark.yarn.staging-dir' is configured then the value used as staging directory otherwise uses the default value i.e. file system’s home directory for the user. Author: Devaraj K Closes #12082 from devaraj-kavali/SPARK-13063. --- docs/running-on-yarn.md | 7 +++++++ .../scala/org/apache/spark/deploy/yarn/Client.scala | 18 +++++++++++++++--- .../scala/org/apache/spark/deploy/yarn/config.scala | 5 +++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index bb83272ec8..ddc75a70b9 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -159,6 +159,13 @@ If you need a reference to the proper location to put log files in the YARN so t HDFS replication level for the files uploaded into HDFS for the application. These include things like the Spark jar, the app jar, and any distributed cache files/archives. +
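As a hedged illustration of the new setting (not part of this patch), an application could opt into an explicit staging location programmatically, for example from `spark-shell`; the `hdfs:///tmp/spark-staging` path below is only a placeholder:

```scala
import org.apache.spark.SparkConf

// Point the YARN staging directory at an explicit location instead of the submitting
// user's home directory (the default when spark.yarn.stagingDir is not set).
val conf = new SparkConf()
  .setAppName("staging-dir-example")
  .set("spark.yarn.stagingDir", "hdfs:///tmp/spark-staging")
```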
+  <td><code>spark.yarn.stagingDir</code></td> +  <td>Current user's home directory in the filesystem</td> +  <td> +  Staging directory used while submitting applications. +  </td> +</tr> diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 336e29fc6b..5e7e3be08d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -182,8 +182,8 @@ private[spark] class Client( val appStagingDir = getAppStagingDir(appId) try { val preserveFiles = sparkConf.get(PRESERVE_STAGING_FILES) - val stagingDirPath = new Path(appStagingDir) val fs = FileSystem.get(hadoopConf) + val stagingDirPath = getAppStagingDirPath(sparkConf, fs, appStagingDir) if (!preserveFiles && fs.exists(stagingDirPath)) { logInfo("Deleting staging directory " + stagingDirPath) fs.delete(stagingDirPath, true) @@ -357,7 +357,7 @@ private[spark] class Client( // Upload Spark and the application JAR to the remote file system if necessary, // and add them as local resources to the application master. val fs = FileSystem.get(hadoopConf) - val dst = new Path(fs.getHomeDirectory(), appStagingDir) + val dst = getAppStagingDirPath(sparkConf, fs, appStagingDir) val nns = YarnSparkHadoopUtil.get.getNameNodesToAccess(sparkConf) + dst YarnSparkHadoopUtil.get.obtainTokensForNamenodes(nns, hadoopConf, credentials) // Used to keep track of URIs added to the distributed cache. If the same URI is added @@ -668,7 +668,7 @@ private[spark] class Client( env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() if (loginFromKeytab) { val remoteFs = FileSystem.get(hadoopConf) - val stagingDirPath = new Path(remoteFs.getHomeDirectory, stagingDir) + val stagingDirPath = getAppStagingDirPath(sparkConf, remoteFs, stagingDir) val credentialsFile = "credentials-" + UUID.randomUUID().toString sparkConf.set(CREDENTIALS_FILE_PATH, new Path(stagingDirPath, credentialsFile).toString) logInfo(s"Credentials file set to: $credentialsFile") @@ -1438,4 +1438,16 @@ private object Client extends Logging { uri.startsWith(s"$LOCAL_SCHEME:") } + /** + * Returns the app staging dir based on the STAGING_DIR configuration if configured, + * otherwise based on the user's home directory. + */ + private def getAppStagingDirPath( + conf: SparkConf, + fs: FileSystem, + appStagingDir: String): Path = { + val baseDir = conf.get(STAGING_DIR).map { new Path(_) }.getOrElse(fs.getHomeDirectory()) + new Path(baseDir, appStagingDir) + } + } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index a3b9134b58..5188a3e229 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -108,6 +108,11 @@ package object config { .intConf .optional + private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir") + .doc("Staging directory used while submitting applications.") + .stringConf + .optional + /* Cluster-mode launcher configuration. */ private[spark] val WAIT_FOR_APP_COMPLETION = ConfigBuilder("spark.yarn.submit.waitAppCompletion") -- cgit v1.2.3 From 72544d6f2a72b9e44e0a32de1fb379e3342be5c3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 5 Apr 2016 12:27:06 -0700 Subject: [SPARK-14123][SPARK-14384][SQL] Handle CreateFunction/DropFunction ## What changes were proposed in this pull request? This PR implements CreateFunction and DropFunction commands. Besides implementing these two commands, we also change how to manage functions. Here are the main changes.
* `FunctionRegistry` will be a container to store all functions builders and it will not actively load any functions. Because of this change, we do not need to maintain a separate registry for HiveContext. So, `HiveFunctionRegistry` is deleted. * SessionCatalog takes care the job of loading a function if this function is not in the `FunctionRegistry` but its metadata is stored in the external catalog. For this case, SessionCatalog will (1) load the metadata from the external catalog, (2) load all needed resources (i.e. jars and files), (3) create a function builder based on the function definition, (4) register the function builder in the `FunctionRegistry`. * A `UnresolvedGenerator` is created. So, the parser will not need to call `FunctionRegistry` directly during parsing, which is not a good time to create a Hive UDTF. In the analysis phase, we will resolve `UnresolvedGenerator`. This PR is based on viirya's https://github.com/apache/spark/pull/12036/ ## How was this patch tested? Existing tests and new tests. ## TODOs [x] Self-review [x] Cleanup [x] More tests for create/drop functions (we need to more tests for permanent functions). [ ] File JIRAs for all TODOs [x] Standardize the error message when a function does not exist. Author: Yin Huai Author: Liang-Chi Hsieh Closes #12117 from yhuai/function. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 17 +- .../sql/catalyst/analysis/FunctionRegistry.scala | 2 + .../spark/sql/catalyst/analysis/unresolved.scala | 31 +++- .../sql/catalyst/catalog/InMemoryCatalog.scala | 5 - .../sql/catalyst/catalog/SessionCatalog.scala | 171 +++++++++++++-------- .../sql/catalyst/catalog/functionResources.scala | 61 ++++++++ .../spark/sql/catalyst/catalog/interface.scala | 16 +- .../apache/spark/sql/catalyst/identifiers.scala | 16 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 14 +- .../spark/sql/catalyst/analysis/AnalysisTest.scala | 2 +- .../catalyst/analysis/DecimalPrecisionSuite.scala | 2 +- .../sql/catalyst/catalog/CatalogTestCases.scala | 20 +-- .../sql/catalyst/catalog/SessionCatalogSuite.scala | 132 ++++------------ .../optimizer/BooleanSimplificationSuite.scala | 1 - .../catalyst/optimizer/EliminateSortsSuite.scala | 2 +- .../sql/catalyst/parser/PlanParserSuite.scala | 15 +- .../scala/org/apache/spark/sql/SQLContext.scala | 18 ++- .../spark/sql/execution/SparkSqlParser.scala | 7 +- .../spark/sql/execution/command/commands.scala | 29 ++-- .../apache/spark/sql/execution/command/ddl.scala | 20 --- .../spark/sql/execution/command/functions.scala | 114 ++++++++++++++ .../apache/spark/sql/internal/SessionState.scala | 11 +- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 +- .../test/scala/org/apache/spark/sql/UDFSuite.scala | 3 +- .../sql/execution/command/DDLCommandSuite.scala | 12 +- .../org/apache/spark/sql/test/SQLTestUtils.scala | 22 +++ .../thriftserver/HiveThriftServer2Suites.scala | 94 +++++------ .../spark/sql/hive/HiveExternalCatalog.scala | 11 +- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 143 ++++++++++++++++- .../apache/spark/sql/hive/HiveSessionState.scala | 18 +-- .../spark/sql/hive/client/HiveClientImpl.scala | 20 ++- .../spark/sql/hive/execution/HiveSqlParser.scala | 17 +- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 119 +------------- .../org/apache/spark/sql/hive/test/TestHive.scala | 16 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 165 ++++++++++++++++++++ .../scala/org/apache/spark/sql/hive/UDFSuite.scala | 167 +++++++++++++++++++- .../spark/sql/hive/execution/HiveQuerySuite.scala | 8 +- 
.../spark/sql/hive/execution/HiveUDFSuite.scala | 10 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 74 ++++++--- 39 files changed, 1100 insertions(+), 513 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3e0a6d29b4..473c91e69e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -45,10 +45,7 @@ object SimpleAnalyzer new SimpleCatalystConf(caseSensitiveAnalysis = true)) class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf) - extends Analyzer( - new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), - functionRegistry, - conf) + extends Analyzer(new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), conf) /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and @@ -57,7 +54,6 @@ class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf) */ class Analyzer( catalog: SessionCatalog, - registry: FunctionRegistry, conf: CatalystConf, maxIterations: Int = 100) extends RuleExecutor[LogicalPlan] with CheckAnalysis { @@ -756,9 +752,18 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. + case u @ UnresolvedGenerator(name, children) => + withPosition(u) { + catalog.lookupFunction(name, children) match { + case generator: Generator => generator + case other => + failAnalysis(s"$name is expected to be a generator. However, " + + s"its class is ${other.getClass.getCanonicalName}, which is not a generator.") + } + } case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { - registry.lookupFunction(name, children) match { + catalog.lookupFunction(name, children) match { // DISTINCT is not meaningful for a Max or a Min. case max: Max if isDistinct => AggregateExpression(max, Complete, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ca8db3cbc5..7af5ffbe47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -52,6 +52,8 @@ trait FunctionRegistry { /** Drop a function and return whether the function existed. */ def dropFunction(name: String): Boolean + /** Checks if a function with a given name exists. 
*/ + def functionExists(name: String): Boolean = lookupFunction(name).isDefined } class SimpleFunctionRegistry extends FunctionRegistry { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index fbbf6302e9..b2f362b6b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{errors, TableIdentifier} +import org.apache.spark.sql.catalyst.{errors, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -133,6 +133,33 @@ object UnresolvedAttribute { } } +/** + * Represents an unresolved generator, which will be created by the parser for + * the [[org.apache.spark.sql.catalyst.plans.logical.Generate]] operator. + * The analyzer will resolve this generator. + */ +case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends Generator { + + override def elementTypes: Seq[(DataType, Boolean, String)] = + throw new UnresolvedException(this, "elementTypes") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false + + override def prettyName: String = name + override def toString: String = s"'$name(${children.mkString(", ")})" + + override def eval(input: InternalRow = null): TraversableOnce[InternalRow] = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override def terminate(): TraversableOnce[InternalRow] = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + case class UnresolvedFunction( name: String, children: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 2bbb970ec9..2af0107fa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -315,11 +315,6 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).functions.put(newName, newFunc) } - override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized { - requireFunctionExists(db, funcDefinition.identifier.funcName) - catalog(db).functions.put(funcDefinition.identifier.funcName, funcDefinition) - } - override def getFunction(db: String, funcName: String): CatalogFunction = synchronized { requireFunctionExists(db, funcName) catalog(db).functions(funcName) 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3b8ce6373d..c08ffbb235 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException, SimpleFunctionRegistry} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} @@ -39,17 +39,21 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} */ class SessionCatalog( externalCatalog: ExternalCatalog, + functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, conf: CatalystConf) { import ExternalCatalog._ - def this(externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry) { - this(externalCatalog, functionRegistry, new SimpleCatalystConf(true)) + def this( + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + conf: CatalystConf) { + this(externalCatalog, DummyFunctionResourceLoader, functionRegistry, conf) } // For testing only. def this(externalCatalog: ExternalCatalog) { - this(externalCatalog, new SimpleFunctionRegistry) + this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true)) } protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan] @@ -439,53 +443,88 @@ class SessionCatalog( */ def dropFunction(name: FunctionIdentifier): Unit = { val db = name.database.getOrElse(currentDb) + val qualified = name.copy(database = Some(db)).unquotedString + if (functionRegistry.functionExists(qualified)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(qualified) + } externalCatalog.dropFunction(db, name.funcName) } - /** - * Alter a metastore function whose name that matches the one specified in `funcDefinition`. - * - * If no database is specified in `funcDefinition`, assume the function is in the - * current database. - * - * Note: If the underlying implementation does not support altering a certain field, - * this becomes a no-op. - */ - def alterFunction(funcDefinition: CatalogFunction): Unit = { - val db = funcDefinition.identifier.database.getOrElse(currentDb) - val newFuncDefinition = funcDefinition.copy( - identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))) - externalCatalog.alterFunction(db, newFuncDefinition) - } - /** * Retrieve the metadata of a metastore function. * * If a database is specified in `name`, this will return the function in that database. * If no database is specified, this will return the function in the current database. */ + // TODO: have a better name. 
This method is actually for fetching the metadata of a function. def getFunction(name: FunctionIdentifier): CatalogFunction = { val db = name.database.getOrElse(currentDb) externalCatalog.getFunction(db, name.funcName) } + /** + * Check if the specified function exists. + */ + def functionExists(name: FunctionIdentifier): Boolean = { + if (functionRegistry.functionExists(name.unquotedString)) { + // This function exists in the FunctionRegistry. + true + } else { + // Need to check if this function exists in the metastore. + try { + // TODO: It's better to ask external catalog if this function exists. + // So, we can avoid of having this hacky try/catch block. + getFunction(name) != null + } catch { + case _: NoSuchFunctionException => false + case _: AnalysisException => false // HiveExternalCatalog wraps all exceptions with it. + } + } + } // ---------------------------------------------------------------- // | Methods that interact with temporary and metastore functions | // ---------------------------------------------------------------- + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * + * This performs reflection to decide what type of [[Expression]] to return in the builder. + */ + private[sql] def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + // TODO: at least support UDAFs here + throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") + } + + /** + * Loads resources such as JARs and Files for a function. Every resource is represented + * by a tuple (resource type, resource uri). + */ + def loadFunctionResources(resources: Seq[(String, String)]): Unit = { + resources.foreach { case (resourceType, uri) => + val functionResource = + FunctionResource(FunctionResourceType.fromString(resourceType.toLowerCase), uri) + functionResourceLoader.loadResource(functionResource) + } + } + /** * Create a temporary function. * This assumes no database is specified in `funcDefinition`. */ def createTempFunction( name: String, + info: ExpressionInfo, funcDefinition: FunctionBuilder, ignoreIfExists: Boolean): Unit = { if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { throw new AnalysisException(s"Temporary function '$name' already exists.") } - functionRegistry.registerFunction(name, funcDefinition) + functionRegistry.registerFunction(name, info, funcDefinition) } /** @@ -501,41 +540,59 @@ class SessionCatalog( } } - /** - * Rename a function. - * - * If a database is specified in `oldName`, this will rename the function in that database. - * If no database is specified, this will first attempt to rename a temporary function with - * the same name, then, if that does not exist, rename the function in the current database. - * - * This assumes the database specified in `oldName` matches the one specified in `newName`. 
- */ - def renameFunction(oldName: FunctionIdentifier, newName: FunctionIdentifier): Unit = { - if (oldName.database != newName.database) { - throw new AnalysisException("rename does not support moving functions across databases") - } - val db = oldName.database.getOrElse(currentDb) - val oldBuilder = functionRegistry.lookupFunctionBuilder(oldName.funcName) - if (oldName.database.isDefined || oldBuilder.isEmpty) { - externalCatalog.renameFunction(db, oldName.funcName, newName.funcName) - } else { - val oldExpressionInfo = functionRegistry.lookupFunction(oldName.funcName).get - val newExpressionInfo = new ExpressionInfo( - oldExpressionInfo.getClassName, - newName.funcName, - oldExpressionInfo.getUsage, - oldExpressionInfo.getExtended) - functionRegistry.dropFunction(oldName.funcName) - functionRegistry.registerFunction(newName.funcName, newExpressionInfo, oldBuilder.get) - } + protected def failFunctionLookup(name: String): Nothing = { + throw new AnalysisException(s"Undefined function: $name. This function is " + + s"neither a registered temporary function nor " + + s"a permanent function registered in the database $currentDb.") } /** * Return an [[Expression]] that represents the specified function, assuming it exists. - * Note: This is currently only used for temporary functions. + * + * For a temporary function or a permanent function that has been loaded, + * this method will simply lookup the function through the + * FunctionRegistry and create an expression based on the builder. + * + * For a permanent function that has not been loaded, we will first fetch its metadata + * from the underlying external catalog. Then, we will load all resources associated + * with this function (i.e. jars and files). Finally, we create a function builder + * based on the function class and put the builder into the FunctionRegistry. + * The name of this function in the FunctionRegistry will be `databaseName.functionName`. */ def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionRegistry.lookupFunction(name, children) + // TODO: Right now, the name can be qualified or not qualified. + // It will be better to get a FunctionIdentifier. + // TODO: Right now, we assume that name is not qualified! + val qualifiedName = FunctionIdentifier(name, Some(currentDb)).unquotedString + if (functionRegistry.functionExists(name)) { + // This function has been already loaded into the function registry. + functionRegistry.lookupFunction(name, children) + } else if (functionRegistry.functionExists(qualifiedName)) { + // This function has been already loaded into the function registry. + // Unlike the above block, we find this function by using the qualified name. + functionRegistry.lookupFunction(qualifiedName, children) + } else { + // The function has not been loaded to the function registry, which means + // that the function is a permanent function (if it actually has been registered + // in the metastore). We need to first put the function in the FunctionRegistry. + val catalogFunction = try { + externalCatalog.getFunction(currentDb, name) + } catch { + case e: AnalysisException => failFunctionLookup(name) + case e: NoSuchFunctionException => failFunctionLookup(name) + } + loadFunctionResources(catalogFunction.resources) + // Please note that qualifiedName is provided by the user. However, + // catalogFunction.identifier.unquotedString is returned by the underlying + // catalog. 
So, it is possible that qualifiedName is not exactly the same as + // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). + // At here, we preserve the input from the user. + val info = new ExpressionInfo(catalogFunction.className, qualifiedName) + val builder = makeFunctionBuilder(qualifiedName, catalogFunction.className) + createTempFunction(qualifiedName, info, builder, ignoreIfExists = false) + // Now, we need to create the Expression. + functionRegistry.lookupFunction(qualifiedName, children) + } } /** @@ -545,17 +602,11 @@ class SessionCatalog( val dbFunctions = externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } val regex = pattern.replaceAll("\\*", ".*").r - val _tempFunctions = functionRegistry.listFunction() + val loadedFunctions = functionRegistry.listFunction() .filter { f => regex.pattern.matcher(f).matches() } .map { f => FunctionIdentifier(f) } - dbFunctions ++ _tempFunctions - } - - /** - * Return a temporary function. For testing only. - */ - private[catalog] def getTempFunction(name: String): Option[FunctionBuilder] = { - functionRegistry.lookupFunctionBuilder(name) + // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry. + // So, the returned list may have two entries for the same function. + dbFunctions ++ loadedFunctions } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala new file mode 100644 index 0000000000..5adcc892cf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.AnalysisException + +/** An trait that represents the type of a resourced needed by a function. */ +sealed trait FunctionResourceType + +object JarResource extends FunctionResourceType + +object FileResource extends FunctionResourceType + +// We do not allow users to specify a archive because it is YARN specific. +// When loading resources, we will throw an exception and ask users to +// use --archive with spark submit. 
+object ArchiveResource extends FunctionResourceType + +object FunctionResourceType { + def fromString(resourceType: String): FunctionResourceType = { + resourceType.toLowerCase match { + case "jar" => JarResource + case "file" => FileResource + case "archive" => ArchiveResource + case other => + throw new AnalysisException(s"Resource Type '$resourceType' is not supported.") + } + } +} + +case class FunctionResource(resourceType: FunctionResourceType, uri: String) + +/** + * A simple trait representing a class that can be used to load resources used by + * a function. Because only a SQLContext can load resources, we create this trait + * to avoid of explicitly passing SQLContext around. + */ +trait FunctionResourceLoader { + def loadResource(resource: FunctionResource): Unit +} + +object DummyFunctionResourceLoader extends FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + throw new UnsupportedOperationException + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 303846d313..97b9946140 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -150,15 +150,6 @@ abstract class ExternalCatalog { def renameFunction(db: String, oldName: String, newName: String): Unit - /** - * Alter a function whose name that matches the one specified in `funcDefinition`, - * assuming the function exists. - * - * Note: If the underlying implementation does not support altering a certain field, - * this becomes a no-op. - */ - def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - def getFunction(db: String, funcName: String): CatalogFunction def listFunctions(db: String, pattern: String): Seq[String] @@ -171,8 +162,13 @@ abstract class ExternalCatalog { * * @param identifier name of the function * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" + * @param resources resource types and Uris used by the function */ -case class CatalogFunction(identifier: FunctionIdentifier, className: String) +// TODO: Use FunctionResource instead of (String, String) as the element type of resources. 
+case class CatalogFunction( + identifier: FunctionIdentifier, + className: String, + resources: Seq[(String, String)]) /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index 87f4d1b007..aae75956ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -25,10 +25,10 @@ package org.apache.spark.sql.catalyst * Format (quoted): "`name`" or "`db`.`name`" */ sealed trait IdentifierWithDatabase { - val name: String + val identifier: String def database: Option[String] - def quotedString: String = database.map(db => s"`$db`.`$name`").getOrElse(s"`$name`") - def unquotedString: String = database.map(db => s"$db.$name").getOrElse(name) + def quotedString: String = database.map(db => s"`$db`.`$identifier`").getOrElse(s"`$identifier`") + def unquotedString: String = database.map(db => s"$db.$identifier").getOrElse(identifier) override def toString: String = quotedString } @@ -36,13 +36,15 @@ sealed trait IdentifierWithDatabase { /** * Identifies a table in a database. * If `database` is not defined, the current database is used. + * When we register a permenent function in the FunctionRegistry, we use + * unquotedString as the function name. */ case class TableIdentifier(table: String, database: Option[String]) extends IdentifierWithDatabase { - override val name: String = table + override val identifier: String = table - def this(name: String) = this(name, None) + def this(table: String) = this(table, None) } @@ -58,9 +60,9 @@ object TableIdentifier { case class FunctionIdentifier(funcName: String, database: Option[String]) extends IdentifierWithDatabase { - override val name: String = funcName + override val identifier: String = funcName - def this(name: String) = this(name, None) + def this(funcName: String) = this(funcName, None) } object FunctionIdentifier { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 14c90918e6..5a3aebff09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -549,8 +549,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Explode(expressions.head) case "json_tuple" => JsonTuple(expressions) - case other => - withGenerator(other, expressions, ctx) + case name => + UnresolvedGenerator(name, expressions) } Generate( @@ -562,16 +562,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { query) } - /** - * Create a [[Generator]]. Override this method in order to support custom Generators. - */ - protected def withGenerator( - name: String, - expressions: Seq[Expression], - ctx: LateralViewContext): Generator = { - throw new ParseException(s"Generator function '$name' is not supported", ctx) - } - /** * Create a joins between two or more logical plans. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3ec95ef5b5..b1fcf011f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -32,7 +32,7 @@ trait AnalysisTest extends PlanTest { val conf = new SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true) - new Analyzer(catalog, EmptyFunctionRegistry, conf) { + new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 1a350bf847..b3b1f5b920 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types._ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) - private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) + private val analyzer = new Analyzer(catalog, conf) private val relation = LocalRelation( AttributeReference("i", IntegerType)(), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 959bd564d9..fbcac09ce2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -433,7 +433,8 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { test("get function") { val catalog = newBasicCatalog() assert(catalog.getFunction("db2", "func1") == - CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass)) + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[(String, String)])) intercept[AnalysisException] { catalog.getFunction("db2", "does_not_exist") } @@ -464,21 +465,6 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { } } - test("alter function") { - val catalog = newBasicCatalog() - assert(catalog.getFunction("db2", "func1").className == funcClass) - catalog.alterFunction("db2", newFunc("func1").copy(className = "muhaha")) - assert(catalog.getFunction("db2", "func1").className == "muhaha") - intercept[AnalysisException] { catalog.alterFunction("db2", newFunc("funcky")) } - } - - test("alter function when database does not exist") { - val catalog = newBasicCatalog() - intercept[AnalysisException] { - catalog.alterFunction("does_not_exist", newFunc()) - } - } - test("list functions") { val catalog = newBasicCatalog() catalog.createFunction("db2", newFunc("func2")) @@ -557,7 +543,7 @@ abstract class CatalogTestUtils { } def newFunc(name: String, database: Option[String] = None): CatalogFunction = { - 
CatalogFunction(FunctionIdentifier(name, database), funcClass) + CatalogFunction(FunctionIdentifier(name, database), funcClass, Seq.empty[(String, String)]) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index acd97592b6..4d56d001b3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias} @@ -685,19 +685,26 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last - catalog.createTempFunction("temp1", tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("temp2", tempFunc2, ignoreIfExists = false) - assert(catalog.getTempFunction("temp1") == Some(tempFunc1)) - assert(catalog.getTempFunction("temp2") == Some(tempFunc2)) - assert(catalog.getTempFunction("temp3") == None) + val info1 = new ExpressionInfo("tempFunc1", "temp1") + val info2 = new ExpressionInfo("tempFunc2", "temp2") + catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction("temp1", arguments) === Literal(1)) + assert(catalog.lookupFunction("temp2", arguments) === Literal(3)) + // Temporary function does not exist. 
+ intercept[AnalysisException] { + catalog.lookupFunction("temp3", arguments) + } val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) + val info3 = new ExpressionInfo("tempFunc3", "temp1") // Temporary function already exists intercept[AnalysisException] { - catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = false) + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) } // Temporary function is overridden - catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = true) - assert(catalog.getTempFunction("temp1") == Some(tempFunc3)) + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) + assert(catalog.lookupFunction("temp1", arguments) === Literal(arguments.length)) } test("drop function") { @@ -726,11 +733,15 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop temp function") { val catalog = new SessionCatalog(newBasicCatalog()) + val info = new ExpressionInfo("tempFunc", "func1") val tempFunc = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", tempFunc, ignoreIfExists = false) - assert(catalog.getTempFunction("func1") == Some(tempFunc)) + catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction("func1", arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) - assert(catalog.getTempFunction("func1") == None) + intercept[AnalysisException] { + catalog.lookupFunction("func1", arguments) + } intercept[AnalysisException] { catalog.dropTempFunction("func1", ignoreIfNotExists = false) } @@ -739,7 +750,9 @@ class SessionCatalogSuite extends SparkFunSuite { test("get function") { val catalog = new SessionCatalog(newBasicCatalog()) - val expected = CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass) + val expected = + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[(String, String)]) assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == expected) // Get function without explicitly specifying database catalog.setCurrentDatabase("db2") @@ -758,8 +771,9 @@ class SessionCatalogSuite extends SparkFunSuite { test("lookup temp function") { val catalog = new SessionCatalog(newBasicCatalog()) + val info1 = new ExpressionInfo("tempFunc1", "func1") val tempFunc1 = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) intercept[AnalysisException] { @@ -767,98 +781,16 @@ class SessionCatalogSuite extends SparkFunSuite { } } - test("rename function") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val newName = "funcky" - assert(sessionCatalog.getFunction( - FunctionIdentifier("func1", Some("db2"))) == newFunc("func1", Some("db2"))) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - sessionCatalog.renameFunction( - FunctionIdentifier("func1", Some("db2")), FunctionIdentifier(newName, Some("db2"))) - assert(sessionCatalog.getFunction( - FunctionIdentifier(newName, Some("db2"))) == newFunc(newName, Some("db2"))) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set(newName)) - // Rename function without explicitly specifying database 
- sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.renameFunction(FunctionIdentifier(newName), FunctionIdentifier("func1")) - assert(sessionCatalog.getFunction( - FunctionIdentifier("func1")) == newFunc("func1", Some("db2"))) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - // Renaming "db2.func1" to "db1.func2" should fail because databases don't match - intercept[AnalysisException] { - sessionCatalog.renameFunction( - FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func2", Some("db1"))) - } - } - - test("rename function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { - catalog.renameFunction( - FunctionIdentifier("func1", Some("does_not_exist")), - FunctionIdentifier("func5", Some("does_not_exist"))) - } - intercept[AnalysisException] { - catalog.renameFunction( - FunctionIdentifier("does_not_exist", Some("db2")), - FunctionIdentifier("x", Some("db2"))) - } - } - - test("rename temp function") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempFunc = (e: Seq[Expression]) => e.head - sessionCatalog.createTempFunction("func1", tempFunc, ignoreIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - // If a database is specified, we'll always rename the function in that database - sessionCatalog.renameFunction( - FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func3", Some("db2"))) - assert(sessionCatalog.getTempFunction("func1") == Some(tempFunc)) - assert(sessionCatalog.getTempFunction("func3") == None) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func3")) - // If no database is specified, we'll first rename temporary functions - sessionCatalog.createFunction(newFunc("func1", Some("db2"))) - sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4")) - assert(sessionCatalog.getTempFunction("func4") == Some(tempFunc)) - assert(sessionCatalog.getTempFunction("func1") == None) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3")) - // Then, if no such temporary function exist, rename the function in the current database - sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func5")) - assert(sessionCatalog.getTempFunction("func5") == None) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func3", "func5")) - } - - test("alter function") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))).className == funcClass) - catalog.alterFunction(newFunc("func1", Some("db2")).copy(className = "muhaha")) - assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))).className == "muhaha") - // Alter function without explicitly specifying database - catalog.setCurrentDatabase("db2") - catalog.alterFunction(newFunc("func1").copy(className = "derpy")) - assert(catalog.getFunction(FunctionIdentifier("func1")).className == "derpy") - } - - test("alter function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { - catalog.alterFunction(newFunc("func5", Some("does_not_exist"))) - } - intercept[AnalysisException] { - catalog.alterFunction(newFunc("funcky", Some("db2"))) - } - } - test("list functions") { val catalog = new SessionCatalog(newBasicCatalog()) + val info1 = new ExpressionInfo("tempFunc1", "func1") + val info2 = new 
ExpressionInfo("tempFunc2", "yes_me") val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last catalog.createFunction(newFunc("func2", Some("db2"))) catalog.createFunction(newFunc("not_me", Some("db2"))) - catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("yes_me", tempFunc2, ignoreIfExists = false) + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) assert(catalog.listFunctions("db1", "*").toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index dd6b5cac28..8147d06969 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -141,7 +141,6 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { private val caseInsensitiveConf = new SimpleCatalystConf(false) private val caseInsensitiveAnalyzer = new Analyzer( new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf), - EmptyFunctionRegistry, caseInsensitiveConf) test("(a && b) || (a && c) => a && (b || c) when case insensitive") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 009889d5a1..8c92ad82ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.rules._ class EliminateSortsSuite extends PlanTest { val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) - val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 9e1660df06..262537d9c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -296,10 +297,18 @@ class PlanParserSuite extends PlanTest { .insertInto("t2"), from.where('s < 10).select(star()).insertInto("t3"))) - // Unsupported generator. - intercept( + // Unresolved generator. 
+ val expected = table("t") + .generate( + UnresolvedGenerator("posexplode", Seq('x)), + join = true, + outer = false, + Some("posexpl"), + Seq("x", "y")) + .select(star()) + assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", - "Generator function 'posexplode' is not supported") + expected) } test("joins") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d4290fee0a..587ba1ea05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} @@ -208,6 +208,22 @@ class SQLContext private[sql]( sparkContext.addJar(path) } + /** A [[FunctionResourceLoader]] that can be used in SessionCatalog. */ + @transient protected[sql] lazy val functionResourceLoader: FunctionResourceLoader = { + new FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + resource.resourceType match { + case JarResource => addJar(resource.uri) + case FileResource => sparkContext.addFile(resource.uri) + case ArchiveResource => + throw new AnalysisException( + "Archive is not allowed to be loaded. If YARN mode is used, " + + "please use --archives options while calling spark-submit.") + } + } + } + } + /** * :: Experimental :: * A collection of methods that are considered experimental, but can be used to hook into diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index fb106d1aef..382cc61fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -337,10 +337,9 @@ class SparkSqlAstBuilder extends AstBuilder { CreateFunction( database, function, - string(ctx.className), // TODO this is not an alias. + string(ctx.className), resources, - ctx.TEMPORARY != null)( - command(ctx)) + ctx.TEMPORARY != null) } /** @@ -353,7 +352,7 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitDropFunction(ctx: DropFunctionContext): LogicalPlan = withOrigin(ctx) { val (database, function) = visitFunctionName(ctx.qualifiedName) - DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null)(command(ctx)) + DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index a4be3bc333..faa7a2cdb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -426,8 +426,12 @@ case class ShowTablePropertiesCommand( * A command for users to list all of the registered functions. 
* The syntax of using this command in SQL is: * {{{ - * SHOW FUNCTIONS + * SHOW FUNCTIONS [LIKE pattern] * }}} + * For the pattern, '*' matches any sequence of characters (including no characters) and + * '|' is for alternation. + * For example, "show functions like 'yea*|windo*'" will return "window" and "year". + * * TODO currently we are simply ignore the db */ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand { @@ -438,18 +442,17 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru schema.toAttributes } - override def run(sqlContext: SQLContext): Seq[Row] = pattern match { - case Some(p) => - try { - val regex = java.util.regex.Pattern.compile(p) - sqlContext.sessionState.functionRegistry.listFunction() - .filter(regex.matcher(_).matches()).map(Row(_)) - } catch { - // probably will failed in the regex that user provided, then returns empty row. - case _: Throwable => Seq.empty[Row] - } - case None => - sqlContext.sessionState.functionRegistry.listFunction().map(Row(_)) + override def run(sqlContext: SQLContext): Seq[Row] = { + val dbName = db.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase) + // If pattern is not specified, we use '*', which is used to + // match any sequence of characters (including no characters). + val functionNames = + sqlContext.sessionState.catalog + .listFunctions(dbName, pattern.getOrElse("*")) + .map(_.unquotedString) + // The session catalog caches some persistent functions in the FunctionRegistry + // so there can be duplicates. + functionNames.distinct.sorted.map(Row(_)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index cd7e0eed65..6896881910 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -175,26 +175,6 @@ case class DescribeDatabase( } } -case class CreateFunction( - databaseName: Option[String], - functionName: String, - alias: String, - resources: Seq[(String, String)], - isTemp: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging - -/** - * The DDL command that drops a function. - * ifExists: returns an error if the function doesn't exist, unless this is true. - * isTemp: indicates if it is a temporary function. - */ -case class DropFunction( - databaseName: Option[String], - functionName: String, - ifExists: Boolean, - isTemp: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging - /** Rename in ALTER TABLE/VIEW: change the name of a table/view to a different name. */ case class AlterTableRename( oldName: TableIdentifier, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala new file mode 100644 index 0000000000..66d17e322e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo + + +/** + * The DDL command that creates a function. + * To create a temporary function, the syntax of using this command in SQL is: + * {{{ + * CREATE TEMPORARY FUNCTION functionName + * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] + * }}} + * + * To create a permanent function, the syntax in SQL is: + * {{{ + * CREATE FUNCTION [databaseName.]functionName + * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] + * }}} + */ +// TODO: Use Seq[FunctionResource] instead of Seq[(String, String)] for resources. +case class CreateFunction( + databaseName: Option[String], + functionName: String, + className: String, + resources: Seq[(String, String)], + isTemp: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + if (isTemp) { + if (databaseName.isDefined) { + throw new AnalysisException( + s"It is not allowed to provide database name when defining a temporary function. " + + s"However, database name ${databaseName.get} is provided.") + } + // We first load resources and then put the builder in the function registry. + // Please note that it is allowed to overwrite an existing temp function. + sqlContext.sessionState.catalog.loadFunctionResources(resources) + val info = new ExpressionInfo(className, functionName) + val builder = + sqlContext.sessionState.catalog.makeFunctionBuilder(functionName, className) + sqlContext.sessionState.catalog.createTempFunction( + functionName, info, builder, ignoreIfExists = false) + } else { + // For a permanent, we will store the metadata into underlying external catalog. + // This function will be loaded into the FunctionRegistry when a query uses it. + // We do not load it into FunctionRegistry right now. + val dbName = databaseName.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase) + val func = FunctionIdentifier(functionName, Some(dbName)) + val catalogFunc = CatalogFunction(func, className, resources) + if (sqlContext.sessionState.catalog.functionExists(func)) { + throw new AnalysisException( + s"Function '$functionName' already exists in database '$dbName'.") + } + sqlContext.sessionState.catalog.createFunction(catalogFunc) + } + Seq.empty[Row] + } +} + +/** + * The DDL command that drops a function. + * ifExists: returns an error if the function doesn't exist, unless this is true. + * isTemp: indicates if it is a temporary function. + */ +case class DropFunction( + databaseName: Option[String], + functionName: String, + ifExists: Boolean, + isTemp: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (isTemp) { + if (databaseName.isDefined) { + throw new AnalysisException( + s"It is not allowed to provide database name when dropping a temporary function. 
" + + s"However, database name ${databaseName.get} is provided.") + } + catalog.dropTempFunction(functionName, ifExists) + } else { + // We are dropping a permanent function. + val dbName = databaseName.getOrElse(catalog.getCurrentDatabase) + val func = FunctionIdentifier(functionName, Some(dbName)) + if (!ifExists && !catalog.functionExists(func)) { + throw new AnalysisException( + s"Function '$functionName' does not exist in database '$dbName'.") + } + catalog.dropFunction(func) + } + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index cd29def3be..69e3358d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -51,7 +51,12 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Internal catalog for managing table and database states. */ - lazy val catalog = new SessionCatalog(ctx.externalCatalog, functionRegistry, conf) + lazy val catalog = + new SessionCatalog( + ctx.externalCatalog, + ctx.functionResourceLoader, + functionRegistry, + conf) /** * Interface exposed to the user for registering user-defined functions. @@ -62,7 +67,7 @@ private[sql] class SessionState(ctx: SQLContext) { * Logical query plan analyzer for resolving unresolved attributes and relations. */ lazy val analyzer: Analyzer = { - new Analyzer(catalog, functionRegistry, conf) { + new Analyzer(catalog, conf) { override val extendedResolutionRules = PreInsertCastAndRename :: DataSourceAnalysis :: @@ -98,5 +103,5 @@ private[sql] class SessionState(ctx: SQLContext) { * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. */ lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx) - } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b727e88668..5a851b47ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -61,8 +61,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .filter(regex.matcher(_).matches()).map(Row(_)) } checkAnswer(sql("SHOW functions"), getFunctions(".*")) - Seq("^c.*", ".*e$", "log.*", ".*date.*").foreach { pattern => - checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) + Seq("^c*", "*e$", "log*", "*date*").foreach { pattern => + // For the pattern part, only '*' and '|' are allowed as wildcards. + // For '*', we need to replace it to '.*'. 
+ checkAnswer( + sql(s"SHOW FUNCTIONS '$pattern'"), + getFunctions(pattern.replaceAll("\\*", ".*"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index fd736718af..ec950332c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -83,7 +83,8 @@ class UDFSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } - assert(e.getMessage.contains("undefined function")) + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("a_function_that_does_not_exist")) } test("Simple UDF") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 47e295a7e7..c42e8e7233 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -147,13 +147,13 @@ class DDLCommandSuite extends PlanTest { "helloworld", "com.matthewrathbone.example.SimpleUDFExample", Seq(("jar", "/path/to/jar1"), ("jar", "/path/to/jar2")), - isTemp = true)(sql1) + isTemp = true) val expected2 = CreateFunction( Some("hello"), "world", "com.matthewrathbone.example.SimpleUDFExample", Seq(("archive", "/path/to/archive"), ("file", "/path/to/file")), - isTemp = false)(sql2) + isTemp = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } @@ -173,22 +173,22 @@ class DDLCommandSuite extends PlanTest { None, "helloworld", ifExists = false, - isTemp = true)(sql1) + isTemp = true) val expected2 = DropFunction( None, "helloworld", ifExists = true, - isTemp = true)(sql2) + isTemp = true) val expected3 = DropFunction( Some("hello"), "world", ifExists = false, - isTemp = false)(sql3) + isTemp = false) val expected4 = DropFunction( Some("hello"), "world", ifExists = true, - isTemp = false)(sql4) + isTemp = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 80a85a6615..7844d1b296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -29,6 +29,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.Filter @@ -131,6 +132,27 @@ private[sql] trait SQLTestUtils try f(dir) finally Utils.deleteRecursively(dir) } + /** + * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). + */ + protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { + try { + f + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. 
+ try functions.foreach { case (functionName, isTemporary) => + val withTemporary = if (isTemporary) "TEMPORARY" else "" + sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + assert( + !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + } + /** * Drops temporary table `tableName` after calling `f`. */ diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 2c7358e59a..a1268b8e94 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -491,46 +491,50 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { test("SPARK-11595 ADD JAR with input path having URL scheme") { withJdbcStatement { statement => - val jarPath = "../hive/src/test/resources/TestUDTF.jar" - val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" + try { + val jarPath = "../hive/src/test/resources/TestUDTF.jar" + val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" - Seq( - s"ADD JAR $jarURL", - s"""CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin - ).foreach(statement.execute) + Seq( + s"ADD JAR $jarURL", + s"""CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin + ).foreach(statement.execute) - val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") - assert(rs1.next()) - assert(rs1.getString(1) === "Function: udtf_count2") + assert(rs1.next()) + assert(rs1.getString(1) === "Function: udtf_count2") - assert(rs1.next()) - assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { - rs1.getString(1) - } + assert(rs1.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs1.getString(1) + } - assert(rs1.next()) - assert(rs1.getString(1) === "Usage: To be added.") + assert(rs1.next()) + assert(rs1.getString(1) === "Usage: To be added.") - val dataPath = "../hive/src/test/resources/data/files/kv1.txt" + val dataPath = "../hive/src/test/resources/data/files/kv1.txt" - Seq( - s"CREATE TABLE test_udtf(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf" - ).foreach(statement.execute) + Seq( + s"CREATE TABLE test_udtf(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf" + ).foreach(statement.execute) - val rs2 = statement.executeQuery( - "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc") + val rs2 = statement.executeQuery( + "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc") - assert(rs2.next()) - assert(rs2.getInt(1) === 97) - assert(rs2.getInt(2) === 500) + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) - assert(rs2.next()) - assert(rs2.getInt(1) === 97) - assert(rs2.getInt(2) === 500) + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) + } finally { + statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2") + } } } @@ -565,24 +569,28 @@ class 
SingleSessionSuite extends HiveThriftJdbcTest { }, { statement => - val rs1 = statement.executeQuery("SET foo") + try { + val rs1 = statement.executeQuery("SET foo") - assert(rs1.next()) - assert(rs1.getString(1) === "foo") - assert(rs1.getString(2) === "bar") + assert(rs1.next()) + assert(rs1.getString(1) === "foo") + assert(rs1.getString(2) === "bar") - val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") - assert(rs2.next()) - assert(rs2.getString(1) === "Function: udtf_count2") + assert(rs2.next()) + assert(rs2.getString(1) === "Function: udtf_count2") - assert(rs2.next()) - assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { - rs2.getString(1) - } + assert(rs2.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs2.getString(1) + } - assert(rs2.next()) - assert(rs2.getString(1) === "Usage: To be added.") + assert(rs2.next()) + assert(rs2.getString(1) === "Usage: To be added.") + } finally { + statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2") + } } ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 11205ae67c..98a5998d03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -272,7 +272,12 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { - client.createFunction(db, funcDefinition) + // Hive's metastore is case insensitive. However, Hive's createFunction does + // not normalize the function name (unlike the getFunction part). So, + // we are normalizing the function name. 
+ val functionName = funcDefinition.identifier.funcName.toLowerCase + val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) + client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } override def dropFunction(db: String, name: String): Unit = withClient { @@ -283,10 +288,6 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.renameFunction(db, oldName, newName) } - override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = withClient { - client.alterFunction(db, funcDefinition) - } - override def getFunction(db: String, funcName: String): CatalogFunction = withClient { client.getFunction(db, funcName) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index dfbf22cc47..d315f39a91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,27 +17,39 @@ package org.apache.spark.sql.hive +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal + import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} +import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils -class HiveSessionCatalog( +private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, client: HiveClient, context: HiveContext, + functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, conf: SQLConf) - extends SessionCatalog(externalCatalog, functionRegistry, conf) { + extends SessionCatalog(externalCatalog, functionResourceLoader, functionRegistry, conf) { override def setCurrentDatabase(db: String): Unit = { super.setCurrentDatabase(db) @@ -112,4 +124,129 @@ class HiveSessionCatalog( metastoreCatalog.cachedDataSourceTables.getIfPresent(key) } + override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { + makeFunctionBuilder(funcName, Utils.classForName(className)) + } + + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. 
+ */ + private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = { + // When we instantiate hive UDF wrapper class, we may throw exception if the input + // expressions don't satisfy the hive UDF, such as type mismatch, input number + // mismatch, etc. Here we catch the exception and throw AnalysisException instead. + (children: Seq[Expression]) => { + try { + if (classOf[UDF].isAssignableFrom(clazz)) { + val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children) + udf.dataType // Force it to check input data types. + udf + } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { + val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children) + udf.dataType // Force it to check input data types. + udf + } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children) + udaf.dataType // Force it to check input data types. + udaf + } else if (classOf[UDAF].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction( + name, + new HiveFunctionWrapper(clazz.getName), + children, + isUDAFBridgeRequired = true) + udaf.dataType // Force it to check input data types. + udaf + } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) + udtf.elementTypes // Force it to check input data types. + udtf + } else { + throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") + } + } catch { + case ae: AnalysisException => + throw ae + case NonFatal(e) => + val analysisException = + new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}': $e") + analysisException.setStackTrace(e.getStackTrace) + throw analysisException + } + } + } + + // We have a list of Hive built-in functions that we do not support. So, we will check + // Hive's function registry and lazily load needed functions into our own function registry. + // Those Hive built-in functions are + // assert_true, collect_list, collect_set, compute_stats, context_ngrams, create_union, + // current_user ,elt, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field, + // histogram_numeric, in_file, index, inline, java_method, map_keys, map_values, + // matchpath, ngrams, noop, noopstreaming, noopwithmap, noopwithmapstreaming, + // parse_url, parse_url_tuple, percentile, percentile_approx, posexplode, reflect, reflect2, + // regexp, sentences, stack, std, str_to_map, windowingtablefunction, xpath, xpath_boolean, + // xpath_double, xpath_float, xpath_int, xpath_long, xpath_number, + // xpath_short, and xpath_string. + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to + // if (super.functionExists(name)) { + // super.lookupFunction(name, children) + // } else { + // // This function is a Hive builtin function. + // ... + // } + Try(super.lookupFunction(name, children)) match { + case Success(expr) => expr + case Failure(error) => + if (functionRegistry.functionExists(name)) { + // If the function actually exists in functionRegistry, it means that there is an + // error when we create the Expression using the given children. + // We need to throw the original exception. + throw error + } else { + // This function is not in functionRegistry, let's try to load it as a Hive's + // built-in function. + // Hive is case insensitive. 
+ val functionName = name.toLowerCase + // TODO: This may not really work for current_user because current_user is not evaluated + // with session info. + // We do not need to use executionHive at here because we only load + // Hive's builtin functions, which do not need current db. + val functionInfo = { + try { + Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse( + failFunctionLookup(name)) + } catch { + // If HiveFunctionRegistry.getFunctionInfo throws an exception, + // we are failing to load a Hive builtin function, which means that + // the given function is not a Hive builtin function. + case NonFatal(e) => failFunctionLookup(name) + } + } + val className = functionInfo.getFunctionClass.getName + val builder = makeFunctionBuilder(functionName, className) + // Put this Hive built-in function to our function registry. + val info = new ExpressionInfo(className, functionName) + createTempFunction(functionName, info, builder, ignoreIfExists = false) + // Now, we need to create the Expression. + functionRegistry.lookupFunction(functionName, children) + } + } + } + + // Pre-load a few commonly used Hive built-in functions. + HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach { + case (functionName, clazz) => + val builder = makeFunctionBuilder(functionName, clazz) + val info = new ExpressionInfo(clazz.getCanonicalName, functionName) + createTempFunction(functionName, info, builder, ignoreIfExists = false) + } +} + +private[sql] object HiveSessionCatalog { + // This is the list of Hive's built-in functions that are commonly used and we want to + // pre-load when we create the FunctionRegistry. + val preloadedHiveBuiltinFunctions = + ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) :: + ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 829afa8432..cff24e28fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -35,26 +35,24 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } - /** - * Internal catalog for managing functions registered by the user. - * Note that HiveUDFs will be overridden by functions registered in this context. - */ - override lazy val functionRegistry: FunctionRegistry = { - new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), ctx.executionHive) - } - /** * Internal catalog for managing table and database states. */ override lazy val catalog = { - new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, functionRegistry, conf) + new HiveSessionCatalog( + ctx.hiveCatalog, + ctx.metadataHive, + ctx, + ctx.functionResourceLoader, + functionRegistry, + conf) } /** * An analyzer that uses the Hive metastore. 
*/ override lazy val analyzer: Analyzer = { - new Analyzer(catalog, functionRegistry, conf) { + new Analyzer(catalog, conf) { override val extendedResolutionRules = catalog.ParquetConversions :: catalog.OrcConversions :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a31178e347..1f66fbfd85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -21,13 +21,14 @@ import java.io.{File, PrintStream} import scala.collection.JavaConverters._ import scala.language.reflectiveCalls +import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.cli.CliSessionState import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} -import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc @@ -37,6 +38,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException} import org.apache.spark.sql.catalyst.catalog._ @@ -611,6 +613,9 @@ private[hive] class HiveClientImpl( .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { + val resourceUris = f.resources.map { case (resourceType, resourcePath) => + new ResourceUri(ResourceType.valueOf(resourceType.toUpperCase), resourcePath) + } new HiveFunction( f.identifier.funcName, db, @@ -619,12 +624,21 @@ private[hive] class HiveClientImpl( PrincipalType.USER, (System.currentTimeMillis / 1000).toInt, FunctionType.JAVA, - List.empty[ResourceUri].asJava) + resourceUris.asJava) } private def fromHiveFunction(hf: HiveFunction): CatalogFunction = { val name = FunctionIdentifier(hf.getFunctionName, Option(hf.getDbName)) - new CatalogFunction(name, hf.getClassName) + val resources = hf.getResourceUris.asScala.map { uri => + val resourceType = uri.getResourceType() match { + case ResourceType.ARCHIVE => "archive" + case ResourceType.FILE => "file" + case ResourceType.JAR => "jar" + case r => throw new AnalysisException(s"Unknown resource type: $r") + } + (resourceType, uri.getUri()) + } + new CatalogFunction(name, hf.getClassName, resources) } private def toHiveColumn(c: CatalogColumn): FieldSchema = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 55e69f99a4..c6c0b2ca59 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -21,13 +21,13 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.parse.EximUtil import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -277,19 +277,6 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { CreateView(tableDesc, plan(query), allowExist, replace, command(ctx)) } - /** - * Create a [[Generator]]. Override this method in order to support custom Generators. - */ - override protected def withGenerator( - name: String, - expressions: Seq[Expression], - ctx: LateralViewContext): Generator = { - val info = Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse { - throw new ParseException(s"Couldn't find Generator function '$name'", ctx) - } - HiveGenericUDTF(name, new HiveFunctionWrapper(info.getFunctionClass.getName), expressions) - } - /** * Create a [[HiveScriptIOSchema]]. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 5ada3d5598..784b018353 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import scala.util.Try import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} @@ -31,130 +30,14 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, O import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{analysis, InternalRow} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types._ -private[hive] class HiveFunctionRegistry( - underlying: analysis.FunctionRegistry, - executionHive: HiveClientImpl) - extends analysis.FunctionRegistry with HiveInspectors { - - def getFunctionInfo(name: String): FunctionInfo = { - // Hive Registry need current database to lookup function - // TODO: the current database of executionHive should be consistent with metadataHive - executionHive.withHiveState { - FunctionRegistry.getFunctionInfo(name) - } - } - - override def lookupFunction(name: String, children: Seq[Expression]): 
Expression = { - Try(underlying.lookupFunction(name, children)).getOrElse { - // We only look it up to see if it exists, but do not include it in the HiveUDF since it is - // not always serializable. - val functionInfo: FunctionInfo = - Option(getFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"undefined function $name")) - - val functionClassName = functionInfo.getFunctionClass.getName - - // When we instantiate hive UDF wrapper class, we may throw exception if the input expressions - // don't satisfy the hive UDF, such as type mismatch, input number mismatch, etc. Here we - // catch the exception and throw AnalysisException instead. - try { - if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { - val udf = HiveGenericUDF( - name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) - udf.dataType // Force it to check input data types. - udf - } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) - udf.dataType // Force it to check input data types. - udf - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) - udf.dataType // Force it to check input data types. - udf - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) - udaf.dataType // Force it to check input data types. - udaf - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udaf = HiveUDAFFunction( - name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) - udaf.dataType // Force it to check input data types. - udaf - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) - udtf.elementTypes // Force it to check input data types. - udtf - } else { - throw new AnalysisException(s"No handler for udf ${functionInfo.getFunctionClass}") - } - } catch { - case analysisException: AnalysisException => - // If the exception is an AnalysisException, just throw it. - throw analysisException - case throwable: Throwable => - // If there is any other error, we throw an AnalysisException. - val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " + - s"because: ${throwable.getMessage}." - val analysisException = new AnalysisException(errorMessage) - analysisException.setStackTrace(throwable.getStackTrace) - throw analysisException - } - } - } - - override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) - : Unit = underlying.registerFunction(name, info, builder) - - /* List all of the registered function names. */ - override def listFunction(): Seq[String] = { - (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted - } - - /* Get the class of the registered function by specified name. 
*/ - override def lookupFunction(name: String): Option[ExpressionInfo] = { - underlying.lookupFunction(name).orElse( - Try { - val info = getFunctionInfo(name) - val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) - if (annotation != null) { - Some(new ExpressionInfo( - info.getFunctionClass.getCanonicalName, - annotation.name(), - annotation.value(), - annotation.extended())) - } else { - Some(new ExpressionInfo( - info.getFunctionClass.getCanonicalName, - name, - null, - null)) - } - }.getOrElse(None)) - } - - override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = { - underlying.lookupFunctionBuilder(name) - } - - // Note: This does not drop functions stored in the metastore - override def dropFunction(name: String): Boolean = { - underlying.dropFunction(name) - } - -} - private[hive] case class HiveSimpleUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 9393302355..7f6ca21782 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -201,8 +201,13 @@ class TestHiveContext private[hive]( } override lazy val functionRegistry = { - new TestHiveFunctionRegistry( - org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), self.executionHive) + // We use TestHiveFunctionRegistry at here to track functions that have been explicitly + // unregistered (through TestHiveFunctionRegistry.unregisterFunction method). + val fr = new TestHiveFunctionRegistry + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.expressions.foreach { + case (name, (info, builder)) => fr.registerFunction(name, info, builder) + } + fr } } @@ -528,19 +533,18 @@ class TestHiveContext private[hive]( } -private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: HiveClientImpl) - extends HiveFunctionRegistry(fr, client) { +private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry { private val removedFunctions = collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))] def unregisterFunction(name: String): Unit = { - fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> f) + functionBuilders.remove(name).foreach(f => removedFunctions += name -> f) } def restore(): Unit = { removedFunctions.foreach { - case (name, (info, builder)) => fr.registerFunction(name, info, builder) + case (name, (info, builder)) => registerFunction(name, info, builder) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 53dec6348f..dd2129375d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -32,6 +32,8 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} 
import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -55,6 +57,57 @@ class HiveSparkSubmitSuite System.setProperty("spark.testing", "true") } + test("temporary Hive UDF: define a UDF and use it") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", TemporaryHiveUDFTest.getClass.getName.stripSuffix("$"), + "--name", "TemporaryHiveUDFTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("permanent Hive UDF: define a UDF and use it") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", PermanentHiveUDFTest1.getClass.getName.stripSuffix("$"), + "--name", "PermanentHiveUDFTest1", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("permanent Hive UDF: use a already defined permanent function") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", PermanentHiveUDFTest2.getClass.getName.stripSuffix("$"), + "--name", "PermanentHiveUDFTest2", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + test("SPARK-8368: includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) @@ -208,6 +261,118 @@ class HiveSparkSubmitSuite } } +// This application is used to test defining a new Hive UDF (with an associated jar) +// and use this UDF. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object TemporaryHiveUDFTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + + // Load a Hive UDF from the jar. 
+ logInfo("Registering a temporary Hive UDF provided in a jar.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + hiveContext.sql( + s""" + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + |USING JAR '$jar' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. + logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP temporary FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + +// This application is used to test defining a new Hive UDF (with an associated jar) +// and use this UDF. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object PermanentHiveUDFTest1 extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + + // Load a Hive UDF from the jar. + logInfo("Registering a permanent Hive UDF provided in a jar.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + hiveContext.sql( + s""" + |CREATE FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + |USING JAR '$jar' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. + logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + +// This application is used to test that a pre-defined permanent function with a jar +// resources can be used. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object PermanentHiveUDFTest2 extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + // Load a Hive UDF from the jar. + logInfo("Write the metadata of a permanent Hive UDF into metastore.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val function = CatalogFunction( + FunctionIdentifier("example_max"), + "org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax", + ("JAR" -> jar) :: Nil) + hiveContext.sessionState.catalog.createFunction(function) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. 
+ logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + // This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368. // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 3ab4576811..d1aa5aa931 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,12 +17,51 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with TestHiveSingleton { +/** + * A test suite for UDF related functionalities. Because Hive metastore is + * case insensitive, database names and function names have both upper case + * letters and lower case letters. + */ +class UDFSuite + extends QueryTest + with SQLTestUtils + with TestHiveSingleton + with BeforeAndAfterEach { + + import hiveContext.implicits._ + + private[this] val functionName = "myUPper" + private[this] val functionNameUpper = "MYUPPER" + private[this] val functionNameLower = "myupper" + + private[this] val functionClass = + classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName + + private var testDF: DataFrame = null + private[this] val testTableName = "testDF_UDFSuite" + private var expectedDF: DataFrame = null + + override def beforeAll(): Unit = { + sql("USE default") + + testDF = (1 to 10).map(i => s"sTr$i").toDF("value") + testDF.registerTempTable(testTableName) + expectedDF = (1 to 10).map(i => s"STR$i").toDF("value") + super.beforeAll() + } + + override def afterEach(): Unit = { + sql("USE default") + super.afterEach() + } test("UDF case insensitive") { hiveContext.udf.register("random0", () => { Math.random() }) @@ -32,4 +71,128 @@ class UDFSuite extends QueryTest with TestHiveSingleton { assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } + + test("temporary function: create and drop") { + withUserDefinedFunction(functionName -> true) { + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY FUNCTION default.$functionName AS '$functionClass'") + } + sql(s"CREATE TEMPORARY FUNCTION $functionName AS '$functionClass'") + checkAnswer( + sql(s"SELECT $functionNameLower(value) from $testTableName"), + expectedDF + ) + intercept[AnalysisException] { + sql(s"DROP TEMPORARY FUNCTION default.$functionName") + } + } + } + + test("permanent function: create and drop without specifying db name") { + withUserDefinedFunction(functionName -> false) { + sql(s"CREATE FUNCTION $functionName AS '$functionClass'") + checkAnswer( + sql("SHOW 
functions like '.*upper'"), + Row(s"default.$functionNameLower") + ) + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + assert( + sql("SHOW functions").collect() + .map(_.getString(0)) + .contains(s"default.$functionNameLower")) + } + } + + test("permanent function: create and drop with a db name") { + // For this block, drop function command uses functionName as the function name. + withUserDefinedFunction(functionNameUpper -> false) { + sql(s"CREATE FUNCTION default.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // in SessionCatalog.lookupFunction. + // checkAnswer( + // sql(s"SELECT default.myuPPer(value) from $testTableName"), + // expectedDF + // ) + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + checkAnswer( + sql(s"SELECT default.$functionName(value) from $testTableName"), + expectedDF + ) + } + + // For this block, drop function command uses default.functionName as the function name. + withUserDefinedFunction(s"DEfault.$functionNameLower" -> false) { + sql(s"CREATE FUNCTION dEFault.$functionName AS '$functionClass'") + checkAnswer( + sql(s"SELECT $functionNameUpper(value) from $testTableName"), + expectedDF + ) + } + } + + test("permanent function: create and drop a function in another db") { + // For this block, drop function command uses functionName as the function name. + withTempDatabase { dbName => + withUserDefinedFunction(functionName -> false) { + sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // checkAnswer( + // sql(s"SELECT $dbName.myuPPer(value) from $testTableName"), + // expectedDF + // ) + + checkAnswer( + sql(s"SHOW FUNCTIONS like $dbName.$functionNameUpper"), + Row(s"$dbName.$functionNameLower") + ) + + sql(s"USE $dbName") + + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + + sql(s"USE default") + + checkAnswer( + sql(s"SELECT $dbName.$functionName(value) from $testTableName"), + expectedDF + ) + + sql(s"USE $dbName") + } + + sql(s"USE default") + + // For this block, drop function command uses default.functionName as the function name. 
+ withUserDefinedFunction(s"$dbName.$functionNameUpper" -> false) { + sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // checkAnswer( + // sql(s"SELECT $dbName.myupper(value) from $testTableName"), + // expectedDF + // ) + + sql(s"USE $dbName") + + assert( + sql("SHOW functions").collect() + .map(_.getString(0)) + .contains(s"$dbName.$functionNameLower")) + checkAnswer( + sql(s"SELECT $functionNameLower(value) from $testTableName"), + expectedDF + ) + + sql(s"USE default") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b951948fda..0c57ede9ed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -62,7 +62,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - sql("DROP TEMPORARY FUNCTION udtf_count2") + sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2") } finally { super.afterAll() } @@ -1230,14 +1230,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val e = intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") } - assert(e.getMessage.contains("undefined function not_a_udf")) + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("not_a_udf")) var success = false val t = new Thread("test") { override def run(): Unit = { val e = intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") } - assert(e.getMessage.contains("undefined function not_a_udf")) + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("not_a_udf")) success = true } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index b0e263dff9..d07ac56586 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -303,7 +303,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDFTwoListList() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") } @@ -313,7 +313,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDFAnd() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd") } @@ -323,7 +323,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile") } @@ -333,7 +333,7 @@ class HiveUDFSuite extends QueryTest with 
TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage") } @@ -343,7 +343,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDTFExplode() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6199253d34..14a1d4cd30 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} @@ -67,22 +68,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext.implicits._ test("UDTF") { - sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) + withUserDefinedFunction("udtf_count2" -> true) { + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) - checkAnswer( - sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + } + + test("permanent UDTF") { + withUserDefinedFunction("udtf_count_temp" -> false) { + sql( + s""" + |CREATE FUNCTION udtf_count_temp + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } } test("SPARK-6835: udtf in lateral view") { @@ -169,9 +191,7 @@ class SQLQuerySuite extends QueryTest 
with SQLTestUtils with TestHiveSingleton { } test("show functions") { - val allBuiltinFunctions = - (FunctionRegistry.builtin.listFunction().toSet[String] ++ - org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted + val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted // The TestContext is shared by all the test cases, some functions may be registered before // this, so we check that all the builtin functions are returned. val allFunctions = sql("SHOW functions").collect().map(r => r(0)) @@ -183,11 +203,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) - checkAnswer(sql("SHOW functions `~`"), Row("~")) + // TODO: Re-enable this test after we fix SPARK-14335. + // checkAnswer(sql("SHOW functions `~`"), Row("~")) checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) - checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear")) // this probably will failed if we add more function with `sha` prefixing. - checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + checkAnswer(sql("SHOW functions `sha*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + // Test '|' for alternation. + checkAnswer( + sql("SHOW functions 'sha*|weekofyea*'"), + Row("sha") :: Row("sha1") :: Row("sha2") :: Row("weekofyear") :: Nil) } test("describe functions") { @@ -211,10 +236,11 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkExistence(sql("describe functioN abcadf"), true, "Function: abcadf not found.") - checkExistence(sql("describe functioN `~`"), true, - "Function: ~", - "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", - "Usage: ~ n - Bitwise not") + // TODO: Re-enable this test after we fix SPARK-14335. + // checkExistence(sql("describe functioN `~`"), true, + // "Function: ~", + // "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", + // "Usage: ~ n - Bitwise not") } test("SPARK-5371: union with null and sum") { -- cgit v1.2.3 From 9ee5c257176d5c7989031d260e74e3eca530c120 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 5 Apr 2016 13:18:39 -0700 Subject: [SPARK-14353] Dataset Time Window `window` API for Python, and SQL ## What changes were proposed in this pull request? The `window` function was added to Dataset with [this PR](https://github.com/apache/spark/pull/12008). This PR adds the Python, and SQL, API for this function. With this PR, SQL, Java, and Scala will share the same APIs as in users can use: - `window(timeColumn, windowDuration)` - `window(timeColumn, windowDuration, slideDuration)` - `window(timeColumn, windowDuration, slideDuration, startTime)` In Python, users can access all APIs above, but in addition they can do - In Python: `window(timeColumn, windowDuration, startTime=...)` that is, they can provide the startTime without providing the `slideDuration`. In this case, we will generate tumbling windows. ## How was this patch tested? Unit tests + manual tests Author: Burak Yavuz Closes #12136 from brkyvz/python-windows. 
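A rough PySpark sketch of how the overloads described above might be combined (assuming a running `sqlContext`; the example frame, column names, and durations below are arbitrary and not taken from this patch):

    from pyspark.sql import functions as F

    df = sqlContext.createDataFrame(
        [("2016-03-11 09:00:07", 1), ("2016-03-11 09:00:18", 2)], ["date", "val"])
    ts = F.col("date").cast("timestamp")

    # window(timeColumn, windowDuration): tumbling 5-second windows.
    df.groupBy(F.window(ts, "5 seconds")).agg(F.sum("val")).show()

    # window(timeColumn, windowDuration, slideDuration): 10-second windows starting every 5 seconds.
    df.groupBy(F.window(ts, "10 seconds", "5 seconds")).agg(F.sum("val")).show()

    # Python-only keyword form: tumbling windows offset by startTime.
    df.groupBy(F.window(ts, "5 seconds", startTime="2 seconds")).agg(F.sum("val")).show()
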
--- python/pyspark/sql/functions.py | 49 +++++++++++++++++++ .../sql/catalyst/analysis/FunctionRegistry.scala | 5 +- .../sql/catalyst/expressions/TimeWindow.scala | 35 +++++++++++++ .../apache/spark/sql/catalyst/trees/TreeNode.scala | 27 +++++++--- .../sql/catalyst/expressions/TimeWindowSuite.scala | 37 +++++++++++++- .../scala/org/apache/spark/sql/functions.scala | 9 ++-- .../spark/sql/DataFrameTimeWindowingSuite.scala | 57 ++++++++++++++++++++++ 7 files changed, 204 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3b20ba5177..5017ab5b36 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1053,6 +1053,55 @@ def to_utc_timestamp(timestamp, tz): return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) +@since(2.0) +@ignore_unicode_prefix +def window(timeColumn, windowDuration, slideDuration=None, startTime=None): + """Bucketize rows into one or more time windows given a timestamp specifying column. Window + starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + the order of months are not supported. + + The time column must be of TimestampType. + + Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid + interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. + If the `slideDuration` is not provided, the windows will be tumbling windows. + + The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start + window intervals. For example, in order to have hourly tumbling windows that start 15 minutes + past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. + + The output column will be a struct called 'window' by default with the nested columns 'start' + and 'end', where 'start' and 'end' will be of `TimestampType`. + + >>> df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) + >>> w.select(w.window.start.cast("string").alias("start"), + ... 
w.window.end.cast("string").alias("end"), "sum").collect() + [Row(start=u'2016-03-11 09:00:05', end=u'2016-03-11 09:00:10', sum=1)] + """ + def check_string_field(field, fieldName): + if not field or type(field) is not str: + raise TypeError("%s should be provided as a string" % fieldName) + + sc = SparkContext._active_spark_context + time_col = _to_java_column(timeColumn) + check_string_field(windowDuration, "windowDuration") + if slideDuration and startTime: + check_string_field(slideDuration, "slideDuration") + check_string_field(startTime, "startTime") + res = sc._jvm.functions.window(time_col, windowDuration, slideDuration, startTime) + elif slideDuration: + check_string_field(slideDuration, "slideDuration") + res = sc._jvm.functions.window(time_col, windowDuration, slideDuration) + elif startTime: + check_string_field(startTime, "startTime") + res = sc._jvm.functions.window(time_col, windowDuration, windowDuration, startTime) + else: + res = sc._jvm.functions.window(time_col, windowDuration) + return Column(res) + + # ---------------------------- misc functions ---------------------------------- @since(1.5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7af5ffbe47..1ebdf49348 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -364,7 +364,10 @@ object FunctionRegistry { } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e - case Failure(e) => throw new AnalysisException(e.getMessage) + case Failure(e) => + // the exception is an invocation exception. To get a meaningful message, we need the + // cause. 
+ throw new AnalysisException(e.getCause.getMessage) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 8e13833486..daf3de95dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.lang.StringUtils +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} @@ -34,6 +35,28 @@ case class TimeWindow( with Unevaluable with NonSQLExpression { + ////////////////////////// + // SQL Constructors + ////////////////////////// + + def this( + timeColumn: Expression, + windowDuration: Expression, + slideDuration: Expression, + startTime: Expression) = { + this(timeColumn, TimeWindow.parseExpression(windowDuration), + TimeWindow.parseExpression(windowDuration), TimeWindow.parseExpression(startTime)) + } + + def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = { + this(timeColumn, TimeWindow.parseExpression(windowDuration), + TimeWindow.parseExpression(windowDuration), 0) + } + + def this(timeColumn: Expression, windowDuration: Expression) = { + this(timeColumn, windowDuration, windowDuration) + } + override def child: Expression = timeColumn override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) override def dataType: DataType = new StructType() @@ -104,6 +127,18 @@ object TimeWindow { cal.microseconds } + /** + * Parses the duration expression to generate the long value for the original constructor so + * that we can use `window` in SQL. + */ + private def parseExpression(expr: Expression): Long = expr match { + case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) + case IntegerLiteral(i) => i.toLong + case NonNullLiteral(l, LongType) => l.toString.toLong + case _ => throw new AnalysisException("The duration and time inputs to window must be " + + "an integer, long or string literal.") + } + def apply( timeColumn: Expression, windowDuration: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 6b7997e903..232ca43588 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -22,6 +22,7 @@ import java.util.UUID import scala.collection.Map import scala.collection.mutable.Stack +import org.apache.commons.lang.ClassUtils import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -365,20 +366,32 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") { + // Skip no-arg constructors that are just there for kryo. 
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) if (ctors.isEmpty) { sys.error(s"No valid constructor for $nodeName") } - val defaultCtor = ctors.maxBy(_.getParameterTypes.size) + val allArgs: Array[AnyRef] = if (otherCopyArgs.isEmpty) { + newArgs + } else { + newArgs ++ otherCopyArgs + } + val defaultCtor = ctors.find { ctor => + if (ctor.getParameterTypes.length != allArgs.length) { + false + } else if (allArgs.contains(null)) { + // if there is a `null`, we can't figure out the class, therefore we should just fallback + // to older heuristic + false + } else { + val argsArray: Array[Class[_]] = allArgs.map(_.getClass) + ClassUtils.isAssignable(argsArray, ctor.getParameterTypes, true /* autoboxing */) + } + }.getOrElse(ctors.maxBy(_.getParameterTypes.length)) // fall back to older heuristic try { CurrentOrigin.withOrigin(origin) { - // Skip no-arg constructors that are just there for kryo. - if (otherCopyArgs.isEmpty) { - defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType] - } else { - defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType] - } + defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] } } catch { case e: java.lang.IllegalArgumentException => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index 71f969aee2..b82cf8d169 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalatest.PrivateMethodTester + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.LongType -class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper { +class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester { test("time window is unevaluable") { intercept[UnsupportedOperationException] { @@ -73,4 +76,36 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper { === seconds) } } + + private val parseExpression = PrivateMethod[Long]('parseExpression) + + test("parse sql expression for duration in microseconds - string") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal("5 seconds"))) + assert(dur.isInstanceOf[Long]) + assert(dur === 5000000) + } + + test("parse sql expression for duration in microseconds - integer") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal(100))) + assert(dur.isInstanceOf[Long]) + assert(dur === 100) + } + + test("parse sql expression for duration in microseconds - long") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType))) + assert(dur.isInstanceOf[Long]) + assert(dur === (2 << 52)) + } + + test("parse sql expression for duration in microseconds - invalid interval") { + intercept[IllegalArgumentException] { + TimeWindow.invokePrivate(parseExpression(Literal("2 apples"))) + } + } + + test("parse sql expression for duration in microseconds - invalid expression") { + intercept[AnalysisException] { + TimeWindow.invokePrivate(parseExpression(Rand(123))) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index da58ba2add..5bc0034cb0 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2574,8 +2574,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time can be as TimestampType or LongType, however when using LongType, - * the time must be given in seconds. + * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for * valid duration identifiers. @@ -2629,8 +2628,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time can be as TimestampType or LongType, however when using LongType, - * the time must be given in seconds. + * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for * valid duration identifiers. @@ -2672,8 +2670,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time can be as TimestampType or LongType, however when using LongType, - * the time must be given in seconds. + * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for * valid duration identifiers. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index e8103a31d5..06584ec21e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -239,4 +239,61 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1)) ) } + + private def withTempTable(f: String => Unit): Unit = { + val tableName = "temp" + Seq( + ("2016-03-27 19:39:34", 1), + ("2016-03-27 19:39:56", 2), + ("2016-03-27 19:39:27", 4)).toDF("time", "value").registerTempTable(tableName) + try { + f(tableName) + } finally { + sqlContext.dropTempTable(tableName) + } + } + + test("time window in SQL with single string expression") { + withTempTable { table => + checkAnswer( + sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + } + + test("time window in SQL with with two expressions") { + withTempTable { table => + checkAnswer( + sqlContext.sql( + s"""select window(time, "10 seconds", 10000000), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + } + + test("time window in SQL with with three expressions") { + withTempTable { table => + 
checkAnswer( + sqlContext.sql( + s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1), + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4), + Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2) + ) + ) + } + } } -- cgit v1.2.3 From c59abad052b7beec4ef550049413e95578e545be Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 5 Apr 2016 13:31:00 -0700 Subject: [SPARK-14402][SQL] initcap UDF doesn't match Hive/Oracle behavior in lowercasing rest of string ## What changes were proposed in this pull request? Current, SparkSQL `initCap` is using `toTitleCase` function. However, `UTF8String.toTitleCase` implementation changes only the first letter and just copy the other letters: e.g. sParK --> SParK. This is the correct implementation `toTitleCase`. ``` hive> select initcap('sParK'); Spark ``` ``` scala> sql("select initcap('sParK')").head res0: org.apache.spark.sql.Row = [SParK] ``` This PR updates the implementation of `initcap` using `toLowerCase` and `toTitleCase`. ## How was this patch tested? Pass the Jenkins tests (including new testcase). Author: Dongjoon Hyun Closes #12175 from dongjoon-hyun/SPARK-14402. --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 11 ++++++++--- .../sql/catalyst/expressions/StringExpressionsSuite.scala | 1 + .../scala/org/apache/spark/sql/StringFunctionsSuite.scala | 6 +++--- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3ee19cc4ad..b6ea03cd5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -618,19 +618,24 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } /** - * Returns string, with the first letter of each word in uppercase. + * Returns string, with the first letter of each word in uppercase, all other letters in lowercase. * Words are delimited by whitespace. */ +@ExpressionDescription( + usage = "_FUNC_(str) - " + + "Returns str, with the first letter of each word in uppercase, all other letters in " + + "lowercase. 
Words are delimited by white space.", + extended = "> SELECT initcap('sPark sql');\n 'Spark Sql'") case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(StringType) override def dataType: DataType = StringType override def nullSafeEval(string: Any): Any = { - string.asInstanceOf[UTF8String].toTitleCase + string.asInstanceOf[UTF8String].toLowerCase.toTitleCase } override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - defineCodeGen(ctx, ev, str => s"$str.toTitleCase()") + defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 99e3b13ce8..2cf8ca7000 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -382,6 +382,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InitCap(Literal("a b")), "A B") checkEvaluation(InitCap(Literal(" a")), " A") checkEvaluation(InitCap(Literal("the test")), "The Test") + checkEvaluation(InitCap(Literal("sParK")), "Spark") // scalastyle:off // non ascii characters are not allowed in the code, so we disable the scalastyle here. checkEvaluation(InitCap(Literal("世界")), "世界") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index e2090b0a83..6809f26968 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -272,12 +272,12 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("initcap function") { - val df = Seq(("ab", "a B")).toDF("l", "r") + val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z") checkAnswer( - df.select(initcap($"l"), initcap($"r")), Row("Ab", "A B")) + df.select(initcap($"x"), initcap($"y"), initcap($"z")), Row("Ab", "A B", "Spark")) checkAnswer( - df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B")) + df.selectExpr("InitCap(x)", "InitCap(y)", "InitCap(z)"), Row("Ab", "A B", "Spark")) } test("number format function") { -- cgit v1.2.3 From 45d8cdee3945bf94d0f1bd93a12e4cb0d416468e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 5 Apr 2016 14:54:07 -0700 Subject: [SPARK-14129][SPARK-14128][SQL] Alter table DDL commands ## What changes were proposed in this pull request? In Spark 2.0, we want to handle the most common `ALTER TABLE` commands ourselves instead of passing the entire query text to Hive. This is done using the new `SessionCatalog` API introduced recently. The commands supported in this patch include: ``` ALTER TABLE ... RENAME TO ... ALTER TABLE ... SET TBLPROPERTIES ... ALTER TABLE ... UNSET TBLPROPERTIES ... ALTER TABLE ... SET LOCATION ... ALTER TABLE ... SET SERDE ... ``` The commands we explicitly do not support are: ``` ALTER TABLE ... CLUSTERED BY ... ALTER TABLE ... SKEWED BY ... ALTER TABLE ... NOT CLUSTERED ALTER TABLE ... NOT SORTED ALTER TABLE ... NOT SKEWED ALTER TABLE ... NOT STORED AS DIRECTORIES ``` For these we throw exceptions complaining that they are not supported. ## How was this patch tested? 
`DDLSuite` Author: Andrew Or Closes #12121 from andrewor14/alter-table-ddl. --- .../sql/catalyst/analysis/FunctionRegistry.scala | 12 + .../sql/catalyst/catalog/SessionCatalog.scala | 42 +++ .../spark/sql/catalyst/util/StringKeyHashMap.scala | 2 + .../spark/sql/execution/SparkSqlParser.scala | 121 ++------ .../apache/spark/sql/execution/command/ddl.scala | 216 +++++++++++--- .../sql/execution/command/DDLCommandSuite.scala | 127 +------- .../spark/sql/execution/command/DDLSuite.scala | 330 ++++++++++++++++++--- .../hive/execution/HiveCompatibilitySuite.scala | 10 +- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 2 +- 9 files changed, 562 insertions(+), 300 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1ebdf49348..f239b33e44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -54,6 +54,10 @@ trait FunctionRegistry { /** Checks if a function with a given name exists. */ def functionExists(name: String): Boolean = lookupFunction(name).isDefined + + /** Clear all registered functions. */ + def clear(): Unit + } class SimpleFunctionRegistry extends FunctionRegistry { @@ -93,6 +97,10 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.remove(name).isDefined } + override def clear(): Unit = { + functionBuilders.clear() + } + def copy(): SimpleFunctionRegistry = synchronized { val registry = new SimpleFunctionRegistry functionBuilders.iterator.foreach { case (name, (info, builder)) => @@ -132,6 +140,10 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } + override def clear(): Unit = { + throw new UnsupportedOperationException + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c08ffbb235..62a3b1c105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -304,11 +304,18 @@ class SessionCatalog( dbTables ++ _tempTables } + // TODO: It's strange that we have both refresh and invalidate here. + /** * Refresh the cache entry for a metastore table, if any. */ def refreshTable(name: TableIdentifier): Unit = { /* no-op */ } + /** + * Invalidate the cache entry for a metastore table, if any. + */ + def invalidateTable(name: TableIdentifier): Unit = { /* no-op */ } + /** * Drop all existing temporary tables. * For testing only. @@ -595,6 +602,11 @@ class SessionCatalog( } } + /** + * List all functions in the specified database, including temporary functions. + */ + def listFunctions(db: String): Seq[FunctionIdentifier] = listFunctions(db, "*") + /** * List all matching functions in the specified database, including temporary functions. */ @@ -609,4 +621,34 @@ class SessionCatalog( // So, the returned list may have two entries for the same function. dbFunctions ++ loadedFunctions } + + + // ----------------- + // | Other methods | + // ----------------- + + /** + * Drop all existing databases (except "default") along with all associated tables, + * partitions and functions, and set the current database to "default". 
+ * + * This is mainly used for tests. + */ + private[sql] def reset(): Unit = { + val default = "default" + listDatabases().filter(_ != default).foreach { db => + dropDatabase(db, ignoreIfNotExists = false, cascade = true) + } + tempTables.clear() + functionRegistry.clear() + // restore built-in functions + FunctionRegistry.builtin.listFunction().foreach { f => + val expressionInfo = FunctionRegistry.builtin.lookupFunction(f) + val functionBuilder = FunctionRegistry.builtin.lookupFunctionBuilder(f) + require(expressionInfo.isDefined, s"built-in function '$f' is missing expression info") + require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder") + functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get) + } + setCurrentDatabase(default) + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala index 191d5e6399..d5d151a580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -41,4 +41,6 @@ class StringKeyHashMap[T](normalizer: (String) => String) { def remove(key: String): Option[T] = base.remove(normalizer(key)) def iterator: Iterator[(String, T)] = base.toIterator + + def clear(): Unit = base.clear() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 382cc61fac..d3086fc91e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import scala.collection.JavaConverters._ -import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.{AbstractSqlParser, AstBuilder, ParseException} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -378,8 +378,7 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { AlterTableRename( visitTableIdentifier(ctx.from), - visitTableIdentifier(ctx.to))( - command(ctx)) + visitTableIdentifier(ctx.to)) } /** @@ -395,8 +394,7 @@ class SparkSqlAstBuilder extends AstBuilder { ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableSetProperties( visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList))( - command(ctx)) + visitTablePropertyList(ctx.tablePropertyList)) } /** @@ -404,17 +402,16 @@ class SparkSqlAstBuilder extends AstBuilder { * * For example: * {{{ - * ALTER TABLE table UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); - * ALTER VIEW view UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); + * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); * }}} */ override def visitUnsetTableProperties( ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableUnsetProperties( visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList), - ctx.EXISTS != null)( - command(ctx)) + visitTablePropertyList(ctx.tablePropertyList).keys.toSeq, + 
ctx.EXISTS != null) } /** @@ -432,116 +429,41 @@ class SparkSqlAstBuilder extends AstBuilder { Option(ctx.STRING).map(string), Option(ctx.tablePropertyList).map(visitTablePropertyList), // TODO a partition spec is allowed to have optional values. This is currently violated. - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))( - command(ctx)) + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) } - /** - * Create an [[AlterTableStorageProperties]] command. - * - * For example: - * {{{ - * ALTER TABLE table CLUSTERED BY (col, ...) [SORTED BY (col, ...)] INTO n BUCKETS; - * }}} - */ + // TODO: don't even bother parsing alter table commands related to bucketing and skewing + override def visitBucketTable(ctx: BucketTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableStorageProperties( - visitTableIdentifier(ctx.tableIdentifier), - visitBucketSpec(ctx.bucketSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... CLUSTERED BY ... INTO N BUCKETS") } - /** - * Create an [[AlterTableNotClustered]] command. - * - * For example: - * {{{ - * ALTER TABLE table NOT CLUSTERED; - * }}} - */ override def visitUnclusterTable(ctx: UnclusterTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableNotClustered(visitTableIdentifier(ctx.tableIdentifier))(command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT CLUSTERED") } - /** - * Create an [[AlterTableNotSorted]] command. - * - * For example: - * {{{ - * ALTER TABLE table NOT SORTED; - * }}} - */ override def visitUnsortTable(ctx: UnsortTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableNotSorted(visitTableIdentifier(ctx.tableIdentifier))(command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SORTED") } - /** - * Create an [[AlterTableSkewed]] command. - * - * For example: - * {{{ - * ALTER TABLE table SKEWED BY (col1, col2) - * ON ((col1_value, col2_value) [, (col1_value, col2_value), ...]) - * [STORED AS DIRECTORIES]; - * }}} - */ override def visitSkewTable(ctx: SkewTableContext): LogicalPlan = withOrigin(ctx) { - val table = visitTableIdentifier(ctx.tableIdentifier) - val (cols, values, storedAsDirs) = visitSkewSpec(ctx.skewSpec) - AlterTableSkewed(table, cols, values, storedAsDirs)(command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... SKEWED BY ...") } - /** - * Create an [[AlterTableNotSorted]] command. - * - * For example: - * {{{ - * ALTER TABLE table NOT SKEWED; - * }}} - */ override def visitUnskewTable(ctx: UnskewTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableNotSkewed(visitTableIdentifier(ctx.tableIdentifier))(command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SKEWED") } - /** - * Create an [[AlterTableNotStoredAsDirs]] command. - * - * For example: - * {{{ - * ALTER TABLE table NOT STORED AS DIRECTORIES - * }}} - */ override def visitUnstoreTable(ctx: UnstoreTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableNotStoredAsDirs(visitTableIdentifier(ctx.tableIdentifier))(command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... NOT STORED AS DIRECTORIES") } - /** - * Create an [[AlterTableSkewedLocation]] command. - * - * For example: - * {{{ - * ALTER TABLE table SET SKEWED LOCATION (col1="loc1" [, (col2, col3)="loc2", ...] 
); - * }}} - */ override def visitSetTableSkewLocations( ctx: SetTableSkewLocationsContext): LogicalPlan = withOrigin(ctx) { - val skewedMap = ctx.skewedLocationList.skewedLocation.asScala.flatMap { - slCtx => - val location = string(slCtx.STRING) - if (slCtx.constant != null) { - Seq(visitStringConstant(slCtx.constant) -> location) - } else { - // TODO this is similar to what was in the original implementation. However this does not - // make to much sense to me since we should be storing a tuple of values (not column - // names) for which we want a dedicated storage location. - visitConstantList(slCtx.constantList).map(_ -> location) - } - }.toMap - - AlterTableSkewedLocation( - visitTableIdentifier(ctx.tableIdentifier), - skewedMap)( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... SET SKEWED LOCATION ...") } /** @@ -703,8 +625,7 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableSetLocation( visitTableIdentifier(ctx.tableIdentifier), Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - visitLocationSpec(ctx.locationSpec))( - command(ctx)) + visitLocationSpec(ctx.locationSpec)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 6896881910..0d38c41a3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogDatabase +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.types._ @@ -175,67 +174,133 @@ case class DescribeDatabase( } } -/** Rename in ALTER TABLE/VIEW: change the name of a table/view to a different name. */ +/** + * A command that renames a table/view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; + * }}} + */ case class AlterTableRename( oldName: TableIdentifier, - newName: TableIdentifier)(sql: String) - extends NativeDDLCommand(sql) with Logging + newName: TableIdentifier) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + catalog.invalidateTable(oldName) + catalog.renameTable(oldName, newName) + Seq.empty[Row] + } -/** Set Properties in ALTER TABLE/VIEW: add metadata to a table/view. */ +} + +/** + * A command that sets table/view properties. 
+ * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...); + * ALTER VIEW view1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...); + * }}} + */ case class AlterTableSetProperties( tableName: TableIdentifier, - properties: Map[String, String])(sql: String) - extends NativeDDLCommand(sql) with Logging + properties: Map[String, String]) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTable(tableName) + val newProperties = table.properties ++ properties + if (DDLUtils.isDatasourceTable(newProperties)) { + throw new AnalysisException( + "alter table properties is not supported for tables defined using the datasource API") + } + val newTable = table.copy(properties = newProperties) + catalog.alterTable(newTable) + Seq.empty[Row] + } + +} -/** Unset Properties in ALTER TABLE/VIEW: remove metadata from a table/view. */ +/** + * A command that unsets table/view properties. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...); + * ALTER VIEW view1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...); + * }}} + */ case class AlterTableUnsetProperties( tableName: TableIdentifier, - properties: Map[String, String], - ifExists: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + propKeys: Seq[String], + ifExists: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTable(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table properties is not supported for datasource tables") + } + if (!ifExists) { + propKeys.foreach { k => + if (!table.properties.contains(k)) { + throw new AnalysisException( + s"attempted to unset non-existent property '$k' in table '$tableName'") + } + } + } + val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) } + val newTable = table.copy(properties = newProperties) + catalog.alterTable(newTable) + Seq.empty[Row] + } + +} +/** + * A command that sets the serde class and/or serde properties of a table/view. 
+ * + * The syntax of this command is: + * {{{ + * ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props]; + * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties; + * }}} + */ case class AlterTableSerDeProperties( tableName: TableIdentifier, serdeClassName: Option[String], serdeProperties: Option[Map[String, String]], - partition: Option[Map[String, String]])(sql: String) - extends NativeDDLCommand(sql) with Logging - -case class AlterTableStorageProperties( - tableName: TableIdentifier, - buckets: BucketSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + partition: Option[Map[String, String]]) + extends RunnableCommand { -case class AlterTableNotClustered( - tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging + // should never happen if we parsed things correctly + require(serdeClassName.isDefined || serdeProperties.isDefined, + "alter table attempted to set neither serde class name nor serde properties") -case class AlterTableNotSorted( - tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTable(tableName) + // Do not support setting serde for datasource tables + if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table serde is not supported for datasource tables") + } + val newTable = table.withNewStorage( + serde = serdeClassName.orElse(table.storage.serde), + serdeProperties = table.storage.serdeProperties ++ serdeProperties.getOrElse(Map())) + catalog.alterTable(newTable) + Seq.empty[Row] + } -case class AlterTableSkewed( - tableName: TableIdentifier, - // e.g. (dt, country) - skewedCols: Seq[String], - // e.g. ('2008-08-08', 'us), ('2009-09-09', 'uk') - skewedValues: Seq[Seq[String]], - storedAsDirs: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging { - - require(skewedValues.forall(_.size == skewedCols.size), - "number of columns in skewed values do not match number of skewed columns provided") } -case class AlterTableNotSkewed( - tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging - -case class AlterTableNotStoredAsDirs( - tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging - -case class AlterTableSkewedLocation( - tableName: TableIdentifier, - skewedMap: Map[String, String])(sql: String) - extends NativeDDLCommand(sql) with Logging - /** * Add Partition in ALTER TABLE/VIEW: add the table/view partitions. * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE, @@ -292,11 +357,53 @@ case class AlterTableSetFileFormat( genericFormat: Option[String])(sql: String) extends NativeDDLCommand(sql) with Logging +/** + * A command that sets the location of a table or a partition. + * + * For normal tables, this just sets the location URI in the table/partition's storage format. + * For datasource tables, this sets a "path" parameter in the table/partition's serde properties. 
+ * + * The syntax of this command is: + * {{{ + * ALTER TABLE table_name [PARTITION partition_spec] SET LOCATION "loc"; + * }}} + */ case class AlterTableSetLocation( tableName: TableIdentifier, partitionSpec: Option[TablePartitionSpec], - location: String)(sql: String) - extends NativeDDLCommand(sql) with Logging + location: String) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTable(tableName) + partitionSpec match { + case Some(spec) => + // Partition spec is specified, so we set the location only for this partition + val part = catalog.getPartition(tableName, spec) + val newPart = + if (DDLUtils.isDatasourceTable(table)) { + part.copy(storage = part.storage.copy( + serdeProperties = part.storage.serdeProperties ++ Map("path" -> location))) + } else { + part.copy(storage = part.storage.copy(locationUri = Some(location))) + } + catalog.alterPartitions(tableName, Seq(newPart)) + case None => + // No partition spec is specified, so we set the location for the table itself + val newTable = + if (DDLUtils.isDatasourceTable(table)) { + table.withNewStorage( + serdeProperties = table.storage.serdeProperties ++ Map("path" -> location)) + } else { + table.withNewStorage(locationUri = Some(location)) + } + catalog.alterTable(newTable) + } + Seq.empty[Row] + } + +} case class AlterTableTouch( tableName: TableIdentifier, @@ -341,3 +448,16 @@ case class AlterTableReplaceCol( restrict: Boolean, cascade: Boolean)(sql: String) extends NativeDDLCommand(sql) with Logging + + +private object DDLUtils { + + def isDatasourceTable(props: Map[String, String]): Boolean = { + props.contains("spark.sql.sources.provider") + } + + def isDatasourceTable(table: CatalogTable): Boolean = { + isDatasourceTable(table.properties) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index c42e8e7233..618c9a58a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -205,10 +205,10 @@ class DDLCommandSuite extends PlanTest { val parsed_view = parser.parsePlan(sql_view) val expected_table = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None))(sql_table) + TableIdentifier("new_table_name", None)) val expected_view = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None))(sql_view) + TableIdentifier("new_table_name", None)) comparePlans(parsed_table, expected_table) comparePlans(parsed_view, expected_view) } @@ -235,14 +235,14 @@ class DDLCommandSuite extends PlanTest { val tableIdent = TableIdentifier("table_name", None) val expected1_table = AlterTableSetProperties( - tableIdent, Map("test" -> "test", "comment" -> "new_comment"))(sql1_table) + tableIdent, Map("test" -> "test", "comment" -> "new_comment")) val expected2_table = AlterTableUnsetProperties( - tableIdent, Map("comment" -> null, "test" -> null), ifExists = false)(sql2_table) + tableIdent, Seq("comment", "test"), ifExists = false) val expected3_table = AlterTableUnsetProperties( - tableIdent, Map("comment" -> null, "test" -> null), ifExists = true)(sql3_table) - val expected1_view = expected1_table.copy()(sql = sql1_view) - val expected2_view = expected2_table.copy()(sql = sql2_view) - val 
expected3_view = expected3_table.copy()(sql = sql3_view) + tableIdent, Seq("comment", "test"), ifExists = true) + val expected1_view = expected1_table + val expected2_view = expected2_table + val expected3_view = expected3_table comparePlans(parsed1_table, expected1_table) comparePlans(parsed2_table, expected2_table) @@ -282,97 +282,24 @@ class DDLCommandSuite extends PlanTest { val parsed5 = parser.parsePlan(sql5) val tableIdent = TableIdentifier("table_name", None) val expected1 = AlterTableSerDeProperties( - tableIdent, Some("org.apache.class"), None, None)(sql1) + tableIdent, Some("org.apache.class"), None, None) val expected2 = AlterTableSerDeProperties( tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - None)(sql2) + None) val expected3 = AlterTableSerDeProperties( - tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None)(sql3) + tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None) val expected4 = AlterTableSerDeProperties( tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))(sql4) + Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) val expected5 = AlterTableSerDeProperties( tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))(sql5) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - comparePlans(parsed5, expected5) - } - - test("alter table: storage properties") { - val sql1 = "ALTER TABLE table_name CLUSTERED BY (dt, country) INTO 10 BUCKETS" - val sql2 = "ALTER TABLE table_name CLUSTERED BY (dt, country) SORTED BY " + - "(dt, country DESC) INTO 10 BUCKETS" - val sql3 = "ALTER TABLE table_name NOT CLUSTERED" - val sql4 = "ALTER TABLE table_name NOT SORTED" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) - val parsed4 = parser.parsePlan(sql4) - val tableIdent = TableIdentifier("table_name", None) - val cols = List("dt", "country") - // TODO: also test the sort directions once we keep track of that - val expected1 = AlterTableStorageProperties( - tableIdent, BucketSpec(10, cols, Nil))(sql1) - val expected2 = AlterTableStorageProperties( - tableIdent, BucketSpec(10, cols, cols))(sql2) - val expected3 = AlterTableNotClustered(tableIdent)(sql3) - val expected4 = AlterTableNotSorted(tableIdent)(sql4) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - } - - test("alter table: skewed") { - val sql1 = - """ - |ALTER TABLE table_name SKEWED BY (dt, country) ON - |(('2008-08-08', 'us'), ('2009-09-09', 'uk'), ('2010-10-10', 'cn')) STORED AS DIRECTORIES - """.stripMargin - val sql2 = - """ - |ALTER TABLE table_name SKEWED BY (dt, country) ON - |('2008-08-08', 'us') STORED AS DIRECTORIES - """.stripMargin - val sql3 = - """ - |ALTER TABLE table_name SKEWED BY (dt, country) ON - |(('2008-08-08', 'us'), ('2009-09-09', 'uk')) - """.stripMargin - val sql4 = "ALTER TABLE table_name NOT SKEWED" - val sql5 = "ALTER TABLE table_name NOT STORED AS DIRECTORIES" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) - val parsed4 = parser.parsePlan(sql4) - val parsed5 = 
parser.parsePlan(sql5) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSkewed( - tableIdent, - Seq("dt", "country"), - Seq(List("2008-08-08", "us"), List("2009-09-09", "uk"), List("2010-10-10", "cn")), - storedAsDirs = true)(sql1) - val expected2 = AlterTableSkewed( - tableIdent, - Seq("dt", "country"), - Seq(List("2008-08-08", "us")), - storedAsDirs = true)(sql2) - val expected3 = AlterTableSkewed( - tableIdent, - Seq("dt", "country"), - Seq(List("2008-08-08", "us"), List("2009-09-09", "uk")), - storedAsDirs = false)(sql3) - val expected4 = AlterTableNotSkewed(tableIdent)(sql4) - val expected5 = AlterTableNotStoredAsDirs(tableIdent)(sql5) + Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) @@ -380,30 +307,6 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed5, expected5) } - test("alter table: skewed location") { - val sql1 = - """ - |ALTER TABLE table_name SET SKEWED LOCATION - |('123'='location1', 'test'='location2') - """.stripMargin - val sql2 = - """ - |ALTER TABLE table_name SET SKEWED LOCATION - |(('2008-08-08', 'us')='location1', 'test'='location2') - """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSkewedLocation( - tableIdent, - Map("123" -> "location1", "test" -> "location2"))(sql1) - val expected2 = AlterTableSkewedLocation( - tableIdent, - Map("2008-08-08" -> "location1", "us" -> "location1", "test" -> "location2"))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - } - // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION partition_spec // [LOCATION 'location1'] partition_spec [LOCATION 'location2'] ...; test("alter table: add partition") { @@ -615,11 +518,11 @@ class DDLCommandSuite extends PlanTest { val expected1 = AlterTableSetLocation( tableIdent, None, - "new location")(sql1) + "new location") val expected2 = AlterTableSetLocation( tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")), - "new location")(sql2) + "new location") comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 885a04af59..d8e2c94a8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -19,34 +19,63 @@ package org.apache.spark.sql.execution.command import java.io.File +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.catalog.CatalogDatabase +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.test.SharedSQLContext -class DDLSuite extends QueryTest with SharedSQLContext { - +class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private val escapedIdentifier = "`(.+)`".r + override def 
afterEach(): Unit = { + try { + // drop all databases, tables and functions after each test + sqlContext.sessionState.catalog.reset() + } finally { + super.afterEach() + } + } + /** * Strip backticks, if any, from the string. */ - def cleanIdentifier(ident: String): String = { + private def cleanIdentifier(ident: String): String = { ident match { case escapedIdentifier(i) => i case plainIdent => plainIdent } } - /** - * Drops database `databaseName` after calling `f`. - */ - private def withDatabase(dbNames: String*)(f: => Unit): Unit = { - try f finally { - dbNames.foreach { name => - sqlContext.sql(s"DROP DATABASE IF EXISTS $name CASCADE") - } - sqlContext.sessionState.catalog.setCurrentDatabase("default") + private def assertUnsupported(query: String): Unit = { + val e = intercept[AnalysisException] { + sql(query) } + assert(e.getMessage.toLowerCase.contains("operation not allowed")) + } + + private def createDatabase(catalog: SessionCatalog, name: String): Unit = { + catalog.createDatabase(CatalogDatabase(name, "", "", Map()), ignoreIfExists = false) + } + + private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { + catalog.createTable(CatalogTable( + identifier = name, + tableType = CatalogTableType.EXTERNAL_TABLE, + storage = CatalogStorageFormat(None, None, None, None, Map()), + schema = Seq()), ignoreIfExists = false) + } + + private def createTablePartition( + catalog: SessionCatalog, + spec: TablePartitionSpec, + tableName: TableIdentifier): Unit = { + val part = CatalogTablePartition(spec, CatalogStorageFormat(None, None, None, None, Map())) + catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) } test("Create/Drop Database") { @@ -55,7 +84,7 @@ class DDLSuite extends QueryTest with SharedSQLContext { val databaseNames = Seq("db1", "`database`") databaseNames.foreach { dbName => - withDatabase(dbName) { + try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") @@ -67,6 +96,8 @@ class DDLSuite extends QueryTest with SharedSQLContext { Map.empty)) sql(s"DROP DATABASE $dbName CASCADE") assert(!catalog.databaseExists(dbNameWithoutBackTicks)) + } finally { + catalog.reset() } } } @@ -76,8 +107,8 @@ class DDLSuite extends QueryTest with SharedSQLContext { val databaseNames = Seq("db1", "`database`") databaseNames.foreach { dbName => - val dbNameWithoutBackTicks = cleanIdentifier(dbName) - withDatabase(dbName) { + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabase(dbNameWithoutBackTicks) assert(db1 == CatalogDatabase( @@ -90,6 +121,8 @@ class DDLSuite extends QueryTest with SharedSQLContext { sql(s"CREATE DATABASE $dbName") }.getMessage assert(message.contains(s"Database '$dbNameWithoutBackTicks' already exists.")) + } finally { + catalog.reset() } } } @@ -99,7 +132,7 @@ class DDLSuite extends QueryTest with SharedSQLContext { val databaseNames = Seq("db1", "`database`") databaseNames.foreach { dbName => - withDatabase(dbName) { + try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) val location = System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db" @@ -129,6 +162,8 @@ class DDLSuite extends QueryTest with SharedSQLContext { Row("Description", "") :: Row("Location", location) :: Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) + } finally { + catalog.reset() } } } @@ -159,6 +194,131 @@ class DDLSuite extends QueryTest with SharedSQLContext { } } + // TODO: test drop database in restrict mode 
+ + test("alter table: rename") { + val catalog = sqlContext.sessionState.catalog + val tableIdent1 = TableIdentifier("tab1", Some("dbx")) + val tableIdent2 = TableIdentifier("tab2", Some("dbx")) + createDatabase(catalog, "dbx") + createDatabase(catalog, "dby") + createTable(catalog, tableIdent1) + assert(catalog.listTables("dbx") == Seq(tableIdent1)) + sql("ALTER TABLE dbx.tab1 RENAME TO dbx.tab2") + assert(catalog.listTables("dbx") == Seq(tableIdent2)) + catalog.setCurrentDatabase("dbx") + // rename without explicitly specifying database + sql("ALTER TABLE tab2 RENAME TO tab1") + assert(catalog.listTables("dbx") == Seq(tableIdent1)) + // table to rename does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.does_not_exist RENAME TO dbx.tab2") + } + // destination database is different + intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 RENAME TO dby.tab2") + } + } + + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assert(catalog.getTable(tableIdent).properties.isEmpty) + // set table properties + sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')") + assert(catalog.getTable(tableIdent).properties == Map("andrew" -> "or14", "kor" -> "bel")) + // set table properties without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')") + assert(catalog.getTable(tableIdent).properties == + Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist SET TBLPROPERTIES ('winner' = 'loser')") + } + // throw exception for datasource tables + convertToDatasourceTable(catalog, tableIdent) + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 SET TBLPROPERTIES ('sora' = 'bol')") + } + assert(e.getMessage.contains("datasource")) + } + + test("alter table: unset properties") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + // unset table properties + sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan')") + sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')") + assert(catalog.getTable(tableIdent).properties == Map("p" -> "an", "c" -> "lan")) + // unset table properties without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')") + assert(catalog.getTable(tableIdent).properties == Map("c" -> "lan")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist UNSET TBLPROPERTIES ('c' = 'lan')") + } + // property to unset does not exist + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('c', 'xyz')") + } + assert(e.getMessage.contains("xyz")) + // property to unset does not exist, but "IF EXISTS" is specified + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')") + assert(catalog.getTable(tableIdent).properties.isEmpty) + // throw exception for datasource tables + convertToDatasourceTable(catalog, 
tableIdent) + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('sora')") + } + assert(e1.getMessage.contains("datasource")) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: bucketing is not supported") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assertUnsupported("ALTER TABLE dbx.tab1 CLUSTERED BY (blood, lemon, grape) INTO 11 BUCKETS") + assertUnsupported("ALTER TABLE dbx.tab1 CLUSTERED BY (fuji) SORTED BY (grape) INTO 5 BUCKETS") + assertUnsupported("ALTER TABLE dbx.tab1 NOT CLUSTERED") + assertUnsupported("ALTER TABLE dbx.tab1 NOT SORTED") + } + + test("alter table: skew is not supported") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assertUnsupported("ALTER TABLE dbx.tab1 SKEWED BY (dt, country) ON " + + "(('2008-08-08', 'us'), ('2009-09-09', 'uk'), ('2010-10-10', 'cn'))") + assertUnsupported("ALTER TABLE dbx.tab1 SKEWED BY (dt, country) ON " + + "(('2008-08-08', 'us'), ('2009-09-09', 'uk')) STORED AS DIRECTORIES") + assertUnsupported("ALTER TABLE dbx.tab1 NOT SKEWED") + assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") + } + // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext test("show tables") { @@ -206,29 +366,129 @@ class DDLSuite extends QueryTest with SharedSQLContext { } test("show databases") { - withDatabase("showdb1A", "showdb2B") { - sql("CREATE DATABASE showdb1A") - sql("CREATE DATABASE showdb2B") + sql("CREATE DATABASE showdb1A") + sql("CREATE DATABASE showdb2B") - assert( - sql("SHOW DATABASES").count() >= 2) + assert( + sql("SHOW DATABASES").count() >= 2) - checkAnswer( - sql("SHOW DATABASES LIKE '*db1A'"), - Row("showdb1A") :: Nil) + checkAnswer( + sql("SHOW DATABASES LIKE '*db1A'"), + Row("showdb1A") :: Nil) - checkAnswer( - sql("SHOW DATABASES LIKE 'showdb1A'"), - Row("showdb1A") :: Nil) + checkAnswer( + sql("SHOW DATABASES LIKE 'showdb1A'"), + Row("showdb1A") :: Nil) - checkAnswer( - sql("SHOW DATABASES LIKE '*db1A|*db2B'"), - Row("showdb1A") :: - Row("showdb2B") :: Nil) + checkAnswer( + sql("SHOW DATABASES LIKE '*db1A|*db2B'"), + Row("showdb1A") :: + Row("showdb2B") :: Nil) - checkAnswer( - sql("SHOW DATABASES LIKE 'non-existentdb'"), - Nil) + checkAnswer( + sql("SHOW DATABASES LIKE 'non-existentdb'"), + Nil) + } + + private def convertToDatasourceTable( + catalog: SessionCatalog, + tableIdent: TableIdentifier): Unit = { + catalog.alterTable(catalog.getTable(tableIdent).copy( + properties = Map("spark.sql.sources.provider" -> "csv"))) + } + + private def testSetLocation(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val partSpec = Map("a" -> "1") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, partSpec, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.getTable(tableIdent).storage.locationUri.isEmpty) + assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) + 
assert(catalog.getPartition(tableIdent, partSpec).storage.serdeProperties.isEmpty) + // Verify that the location is set to the expected string + def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { + val storageFormat = spec + .map { s => catalog.getPartition(tableIdent, s).storage } + .getOrElse { catalog.getTable(tableIdent).storage } + if (isDatasourceTable) { + assert(storageFormat.serdeProperties.get("path") === Some(expected)) + } else { + assert(storageFormat.locationUri === Some(expected)) + } + } + // set table location + sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") + verifyLocation("/path/to/your/lovely/heart") + // set table partition location + sql("ALTER TABLE dbx.tab1 PARTITION (a='1') SET LOCATION '/path/to/part/ways'") + verifyLocation("/path/to/part/ways", Some(partSpec)) + // set table location without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") + verifyLocation("/swanky/steak/place") + // set table partition location without explicitly specifying database + sql("ALTER TABLE tab1 PARTITION (a='1') SET LOCATION 'vienna'") + verifyLocation("vienna", Some(partSpec)) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'") + } + // partition to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 PARTITION (b='2') SET LOCATION '/mister/spark'") } } + + private def testSetSerde(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.getTable(tableIdent).storage.serde.isEmpty) + assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + // set table serde and/or properties (should fail on datasource tables) + if (isDatasourceTable) { + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 SET SERDE 'whatever'") + } + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") + } + assert(e1.getMessage.contains("datasource")) + assert(e2.getMessage.contains("datasource")) + } else { + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'") + assert(catalog.getTable(tableIdent).storage.serde == Some("org.apache.jadoop")) + assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") + assert(catalog.getTable(tableIdent).storage.serde == Some("org.apache.madoop")) + assert(catalog.getTable(tableIdent).storage.serdeProperties == + Map("k" -> "v", "kay" -> "vee")) + } + // set serde properties only + sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") + assert(catalog.getTable(tableIdent).storage.serdeProperties == + Map("k" -> "vvv", "kay" -> "vee")) + // set things without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')") + assert(catalog.getTable(tableIdent).storage.serdeProperties == + Map("k" -> "vvv", "kay" -> "veee")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES 
('x' = 'y')") + } + } + } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 4b4f88ece0..b01f556f0a 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -360,6 +360,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "show_create_table_serde", "show_create_table_view", + // These tests try to change how a table is bucketed, which we don't support + "alter4", + "sort_merge_join_desc_5", + "sort_merge_join_desc_6", + "sort_merge_join_desc_7", + // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", @@ -381,7 +387,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alias_casted_column", "alter2", "alter3", - "alter4", "alter5", "alter_merge_2", "alter_partition_format_loc", @@ -880,9 +885,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "sort_merge_join_desc_2", "sort_merge_join_desc_3", "sort_merge_join_desc_4", - "sort_merge_join_desc_5", - "sort_merge_join_desc_6", - "sort_merge_join_desc_7", "stats0", "stats_aggregator_error_1", "stats_empty_partition", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index d315f39a91..0cccc22e5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -94,7 +94,7 @@ private[sql] class HiveSessionCatalog( metastoreCatalog.refreshTable(name) } - def invalidateTable(name: TableIdentifier): Unit = { + override def invalidateTable(name: TableIdentifier): Unit = { metastoreCatalog.invalidateTable(name) } -- cgit v1.2.3 From 7329fe272d3ead7db9bc3e1e32adb7329dabc607 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 5 Apr 2016 15:18:35 -0700 Subject: [SPARK-14411][SQL] Add a note to warn that onQueryProgress is asynchronous ## What changes were proposed in this pull request? onQueryProgress is asynchronous so the user may see some future status of `ContinuousQuery`. This PR just updated comments to warn it. ## How was this patch tested? Only updated comments. Author: Shixiong Zhu Closes #12180 from zsxwing/ContinuousQueryListener-doc. --- .../org/apache/spark/sql/util/ContinuousQueryListener.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala index 2c5358cbd7..bf78be9d9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala @@ -34,11 +34,19 @@ abstract class ContinuousQueryListener { * @note This is called synchronously with * [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.startStream()`]], * that is, `onQueryStart` will be called on all listeners before - * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. + * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. 
Please + * don't block this method as it will block your query. */ def onQueryStarted(queryStarted: QueryStarted) - /** Called when there is some status update (ingestion rate updated, etc. */ + /** + * Called when there is some status update (ingestion rate updated, etc.) + * + * @note This method is asynchronous. The status in [[ContinuousQuery]] will always be + * latest no matter when this method is called. Therefore, the status of [[ContinuousQuery]] + * may be changed before/when you process the event. E.g., you may find [[ContinuousQuery]] + * is terminated when you are processing [[QueryProgress]]. + */ def onQueryProgress(queryProgress: QueryProgress) /** Called when a query is stopped, with or without error */ -- cgit v1.2.3 From d5ee9d5c240fca5c15b21efc4a760b06a1f39fd6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 5 Apr 2016 15:19:51 -0700 Subject: [SPARK-529][SQL] Modify SQLConf to use new config API from core. Because SQL keeps track of all known configs, some customization was needed in SQLConf to allow that, since the core API does not have that feature. Tested via existing (and slightly updated) unit tests. Author: Marcelo Vanzin Closes #11570 from vanzin/SPARK-529-sql. --- .../spark/internal/config/ConfigBuilder.scala | 44 +- .../org/apache/spark/internal/config/package.scala | 44 +- .../spark/internal/config/ConfigEntrySuite.scala | 28 +- project/SparkBuild.scala | 18 +- .../scala/org/apache/spark/sql/SQLContext.scala | 12 +- .../org/apache/spark/sql/internal/SQLConf.scala | 771 +++++++++------------ .../spark/sql/internal/SQLConfEntrySuite.scala | 29 +- .../apache/spark/sql/internal/SQLConfSuite.scala | 7 +- .../org/apache/spark/sql/hive/HiveContext.scala | 96 +-- .../org/apache/spark/deploy/yarn/config.scala | 92 +-- 10 files changed, 551 insertions(+), 590 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index 770b43697a..5d50e3851a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -85,10 +85,12 @@ private[spark] class TypedConfigBuilder[T]( this(parent, converter, Option(_).map(_.toString).orNull) } + /** Apply a transformation to the user-provided values of the config entry. */ def transform(fn: T => T): TypedConfigBuilder[T] = { new TypedConfigBuilder(parent, s => fn(converter(s)), stringConverter) } + /** Check that user-provided values for the config match a pre-defined set. */ def checkValues(validValues: Set[T]): TypedConfigBuilder[T] = { transform { v => if (!validValues.contains(v)) { @@ -99,30 +101,38 @@ private[spark] class TypedConfigBuilder[T]( } } + /** Turns the config entry into a sequence of values of the underlying type. */ def toSequence: TypedConfigBuilder[Seq[T]] = { new TypedConfigBuilder(parent, stringToSeq(_, converter), seqToString(_, stringConverter)) } - /** Creates a [[ConfigEntry]] that does not require a default value. */ - def optional: OptionalConfigEntry[T] = { - new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc, parent._public) + /** Creates a [[ConfigEntry]] that does not have a default value. */ + def createOptional: OptionalConfigEntry[T] = { + val entry = new OptionalConfigEntry[T](parent.key, converter, stringConverter, parent._doc, + parent._public) + parent._onCreate.foreach(_(entry)) + entry } /** Creates a [[ConfigEntry]] that has a default value. 
*/ - def withDefault(default: T): ConfigEntry[T] = { + def createWithDefault(default: T): ConfigEntry[T] = { val transformedDefault = converter(stringConverter(default)) - new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter, stringConverter, - parent._doc, parent._public) + val entry = new ConfigEntryWithDefault[T](parent.key, transformedDefault, converter, + stringConverter, parent._doc, parent._public) + parent._onCreate.foreach(_(entry)) + entry } /** * Creates a [[ConfigEntry]] that has a default value. The default value is provided as a * [[String]] and must be a valid value for the entry. */ - def withDefaultString(default: String): ConfigEntry[T] = { + def createWithDefaultString(default: String): ConfigEntry[T] = { val typedDefault = converter(default) - new ConfigEntryWithDefault[T](parent.key, typedDefault, converter, stringConverter, parent._doc, - parent._public) + val entry = new ConfigEntryWithDefault[T](parent.key, typedDefault, converter, stringConverter, + parent._doc, parent._public) + parent._onCreate.foreach(_(entry)) + entry } } @@ -136,10 +146,11 @@ private[spark] case class ConfigBuilder(key: String) { import ConfigHelpers._ - var _public = true - var _doc = "" + private[config] var _public = true + private[config] var _doc = "" + private[config] var _onCreate: Option[ConfigEntry[_] => Unit] = None - def internal: ConfigBuilder = { + def internal(): ConfigBuilder = { _public = false this } @@ -149,6 +160,15 @@ private[spark] case class ConfigBuilder(key: String) { this } + /** + * Registers a callback for when the config entry is finally instantiated. Currently used by + * SQLConf to keep track of SQL configuration entries. + */ + def onCreate(callback: ConfigEntry[_] => Unit): ConfigBuilder = { + _onCreate = Option(callback) + this + } + def intConf: TypedConfigBuilder[Int] = { new TypedConfigBuilder(this, toNumber(_, _.toInt, key, "int")) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 968c5192ac..94b50ee065 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -23,68 +23,70 @@ import org.apache.spark.network.util.ByteUnit package object config { private[spark] val DRIVER_CLASS_PATH = - ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.optional + ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.createOptional private[spark] val DRIVER_JAVA_OPTIONS = - ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS).stringConf.optional + ConfigBuilder(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS).stringConf.createOptional private[spark] val DRIVER_LIBRARY_PATH = - ConfigBuilder(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH).stringConf.optional + ConfigBuilder(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH).stringConf.createOptional private[spark] val DRIVER_USER_CLASS_PATH_FIRST = - ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.withDefault(false) + ConfigBuilder("spark.driver.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val DRIVER_MEMORY = ConfigBuilder("spark.driver.memory") .bytesConf(ByteUnit.MiB) - .withDefaultString("1g") + .createWithDefaultString("1g") private[spark] val EXECUTOR_CLASS_PATH = - ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.optional + ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional private[spark] val EXECUTOR_JAVA_OPTIONS = - 
ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.optional + ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_JAVA_OPTIONS).stringConf.createOptional private[spark] val EXECUTOR_LIBRARY_PATH = - ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_LIBRARY_PATH).stringConf.optional + ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_LIBRARY_PATH).stringConf.createOptional private[spark] val EXECUTOR_USER_CLASS_PATH_FIRST = - ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.withDefault(false) + ConfigBuilder("spark.executor.userClassPathFirst").booleanConf.createWithDefault(false) private[spark] val EXECUTOR_MEMORY = ConfigBuilder("spark.executor.memory") .bytesConf(ByteUnit.MiB) - .withDefaultString("1g") + .createWithDefaultString("1g") - private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal - .booleanConf.withDefault(false) + private[spark] val IS_PYTHON_APP = ConfigBuilder("spark.yarn.isPython").internal() + .booleanConf.createWithDefault(false) - private[spark] val CPUS_PER_TASK = ConfigBuilder("spark.task.cpus").intConf.withDefault(1) + private[spark] val CPUS_PER_TASK = ConfigBuilder("spark.task.cpus").intConf.createWithDefault(1) private[spark] val DYN_ALLOCATION_MIN_EXECUTORS = - ConfigBuilder("spark.dynamicAllocation.minExecutors").intConf.withDefault(0) + ConfigBuilder("spark.dynamicAllocation.minExecutors").intConf.createWithDefault(0) private[spark] val DYN_ALLOCATION_INITIAL_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.initialExecutors") .fallbackConf(DYN_ALLOCATION_MIN_EXECUTORS) private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = - ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.withDefault(Int.MaxValue) + ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) private[spark] val SHUFFLE_SERVICE_ENABLED = - ConfigBuilder("spark.shuffle.service.enabled").booleanConf.withDefault(false) + ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) private[spark] val KEYTAB = ConfigBuilder("spark.yarn.keytab") .doc("Location of user's keytab.") - .stringConf.optional + .stringConf.createOptional private[spark] val PRINCIPAL = ConfigBuilder("spark.yarn.principal") .doc("Name of the Kerberos principal.") - .stringConf.optional + .stringConf.createOptional - private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances").intConf.optional + private[spark] val EXECUTOR_INSTANCES = ConfigBuilder("spark.executor.instances") + .intConf + .createOptional private[spark] val PY_FILES = ConfigBuilder("spark.submit.pyFiles") - .internal + .internal() .stringConf .toSequence - .withDefault(Nil) + .createWithDefault(Nil) } diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index 0644148eae..337fd7e85e 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -26,7 +26,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: int") { val conf = new SparkConf() - val iConf = ConfigBuilder("spark.int").intConf.withDefault(1) + val iConf = ConfigBuilder("spark.int").intConf.createWithDefault(1) assert(conf.get(iConf) === 1) conf.set(iConf, 2) assert(conf.get(iConf) === 2) @@ -34,21 +34,21 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: long") { val conf = new SparkConf() - val lConf = 
ConfigBuilder("spark.long").longConf.withDefault(0L) + val lConf = ConfigBuilder("spark.long").longConf.createWithDefault(0L) conf.set(lConf, 1234L) assert(conf.get(lConf) === 1234L) } test("conf entry: double") { val conf = new SparkConf() - val dConf = ConfigBuilder("spark.double").doubleConf.withDefault(0.0) + val dConf = ConfigBuilder("spark.double").doubleConf.createWithDefault(0.0) conf.set(dConf, 20.0) assert(conf.get(dConf) === 20.0) } test("conf entry: boolean") { val conf = new SparkConf() - val bConf = ConfigBuilder("spark.boolean").booleanConf.withDefault(false) + val bConf = ConfigBuilder("spark.boolean").booleanConf.createWithDefault(false) assert(!conf.get(bConf)) conf.set(bConf, true) assert(conf.get(bConf)) @@ -56,7 +56,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: optional") { val conf = new SparkConf() - val optionalConf = ConfigBuilder("spark.optional").intConf.optional + val optionalConf = ConfigBuilder("spark.optional").intConf.createOptional assert(conf.get(optionalConf) === None) conf.set(optionalConf, 1) assert(conf.get(optionalConf) === Some(1)) @@ -64,7 +64,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: fallback") { val conf = new SparkConf() - val parentConf = ConfigBuilder("spark.int").intConf.withDefault(1) + val parentConf = ConfigBuilder("spark.int").intConf.createWithDefault(1) val confWithFallback = ConfigBuilder("spark.fallback").fallbackConf(parentConf) assert(conf.get(confWithFallback) === 1) conf.set(confWithFallback, 2) @@ -74,7 +74,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: time") { val conf = new SparkConf() - val time = ConfigBuilder("spark.time").timeConf(TimeUnit.SECONDS).withDefaultString("1h") + val time = ConfigBuilder("spark.time").timeConf(TimeUnit.SECONDS).createWithDefaultString("1h") assert(conf.get(time) === 3600L) conf.set(time.key, "1m") assert(conf.get(time) === 60L) @@ -82,7 +82,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: bytes") { val conf = new SparkConf() - val bytes = ConfigBuilder("spark.bytes").bytesConf(ByteUnit.KiB).withDefaultString("1m") + val bytes = ConfigBuilder("spark.bytes").bytesConf(ByteUnit.KiB).createWithDefaultString("1m") assert(conf.get(bytes) === 1024L) conf.set(bytes.key, "1k") assert(conf.get(bytes) === 1L) @@ -90,7 +90,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: string seq") { val conf = new SparkConf() - val seq = ConfigBuilder("spark.seq").stringConf.toSequence.withDefault(Seq()) + val seq = ConfigBuilder("spark.seq").stringConf.toSequence.createWithDefault(Seq()) conf.set(seq.key, "1,,2, 3 , , 4") assert(conf.get(seq) === Seq("1", "2", "3", "4")) conf.set(seq, Seq("1", "2")) @@ -99,7 +99,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: int seq") { val conf = new SparkConf() - val seq = ConfigBuilder("spark.seq").intConf.toSequence.withDefault(Seq()) + val seq = ConfigBuilder("spark.seq").intConf.toSequence.createWithDefault(Seq()) conf.set(seq.key, "1,,2, 3 , , 4") assert(conf.get(seq) === Seq(1, 2, 3, 4)) conf.set(seq, Seq(1, 2)) @@ -111,7 +111,7 @@ class ConfigEntrySuite extends SparkFunSuite { val transformationConf = ConfigBuilder("spark.transformation") .stringConf .transform(_.toLowerCase()) - .withDefault("FOO") + .createWithDefault("FOO") assert(conf.get(transformationConf) === "foo") conf.set(transformationConf, "BAR") @@ -123,7 +123,7 @@ class ConfigEntrySuite extends SparkFunSuite { val enum = ConfigBuilder("spark.enum") .stringConf 
.checkValues(Set("a", "b", "c")) - .withDefault("a") + .createWithDefault("a") assert(conf.get(enum) === "a") conf.set(enum, "b") @@ -138,7 +138,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("conf entry: conversion error") { val conf = new SparkConf() - val conversionTest = ConfigBuilder("spark.conversionTest").doubleConf.optional + val conversionTest = ConfigBuilder("spark.conversionTest").doubleConf.createOptional conf.set(conversionTest.key, "abc") val conversionError = intercept[IllegalArgumentException] { conf.get(conversionTest) @@ -148,7 +148,7 @@ class ConfigEntrySuite extends SparkFunSuite { test("default value handling is null-safe") { val conf = new SparkConf() - val stringConf = ConfigBuilder("spark.string").stringConf.withDefault(null) + val stringConf = ConfigBuilder("spark.string").stringConf.createWithDefault(null) assert(conf.get(stringConf) === null) } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b32480b164..60124ef0a1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -20,6 +20,7 @@ import java.nio.file.Files import scala.util.Properties import scala.collection.JavaConverters._ +import scala.collection.mutable.Stack import sbt._ import sbt.Classpaths.publishTask @@ -742,8 +743,21 @@ object TestSettings { parallelExecution in Test := false, // Make sure the test temp directory exists. resourceGenerators in Test <+= resourceManaged in Test map { outDir: File => - if (!new File(testTempDir).isDirectory()) { - require(new File(testTempDir).mkdirs(), s"Error creating temp directory $testTempDir.") + var dir = new File(testTempDir) + if (!dir.isDirectory()) { + // Because File.mkdirs() can fail if multiple callers are trying to create the same + // parent directory, this code tries to create parents one at a time, and avoids + // failures when the directories have been created by somebody else. + val stack = new Stack[File]() + while (!dir.isDirectory()) { + stack.push(dir) + dir = dir.getParentFile() + } + + while (stack.nonEmpty) { + val d = stack.pop() + require(d.mkdir() || d.isDirectory(), s"Failed to create directory $d") + } } Seq[File]() }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 587ba1ea05..1c9cb79ba4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -29,6 +29,7 @@ import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalyst._ @@ -41,7 +42,6 @@ import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.internal.{SessionState, SQLConf} -import org.apache.spark.sql.internal.SQLConf.SQLConfEntry import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager @@ -138,7 +138,7 @@ class SQLContext private[sql]( def setConf(props: Properties): Unit = conf.setConf(props) /** Set the given Spark SQL configuration property. 
*/ - private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = conf.setConf(entry, value) + private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = conf.setConf(entry, value) /** * Set the given Spark SQL configuration property. @@ -158,16 +158,16 @@ class SQLContext private[sql]( /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue` in [[SQLConfEntry]]. + * yet, return `defaultValue` in [[ConfigEntry]]. */ - private[sql] def getConf[T](entry: SQLConfEntry[T]): T = conf.getConf(entry) + private[sql] def getConf[T](entry: ConfigEntry[T]): T = conf.getConf(entry) /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the * desired one. */ - private[sql] def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + private[sql] def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = { conf.getConf(entry, defaultValue) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index a7c0be63fc..927af89949 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -25,6 +25,8 @@ import scala.collection.immutable import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.util.Utils @@ -36,418 +38,305 @@ import org.apache.spark.util.Utils object SQLConf { private val sqlConfEntries = java.util.Collections.synchronizedMap( - new java.util.HashMap[String, SQLConfEntry[_]]()) + new java.util.HashMap[String, ConfigEntry[_]]()) - /** - * An entry contains all meta information for a configuration. - * - * @param key the key for the configuration - * @param defaultValue the default value for the configuration - * @param valueConverter how to convert a string to the value. It should throw an exception if the - * string does not have the required format. - * @param stringConverter how to convert a value to a string that the user can use it as a valid - * string value. It's usually `toString`. But sometimes, a custom converter - * is necessary. E.g., if T is List[String], `a, b, c` is better than - * `List(a, b, c)`. - * @param doc the document for the configuration - * @param isPublic if this configuration is public to the user. If it's `false`, this - * configuration is only used internally and we should not expose it to the user. - * @tparam T the value type - */ - class SQLConfEntry[T] private( - val key: String, - val defaultValue: Option[T], - val valueConverter: String => T, - val stringConverter: T => String, - val doc: String, - val isPublic: Boolean) { - - def defaultValueString: String = defaultValue.map(stringConverter).getOrElse("") - - override def toString: String = { - s"SQLConfEntry(key = $key, defaultValue=$defaultValueString, doc=$doc, isPublic = $isPublic)" - } + private def register(entry: ConfigEntry[_]): Unit = sqlConfEntries.synchronized { + require(!sqlConfEntries.containsKey(entry.key), + s"Duplicate SQLConfigEntry. 
${entry.key} has been registered") + sqlConfEntries.put(entry.key, entry) } - object SQLConfEntry { - - private def apply[T]( - key: String, - defaultValue: Option[T], - valueConverter: String => T, - stringConverter: T => String, - doc: String, - isPublic: Boolean): SQLConfEntry[T] = - sqlConfEntries.synchronized { - if (sqlConfEntries.containsKey(key)) { - throw new IllegalArgumentException(s"Duplicate SQLConfEntry. $key has been registered") - } - val entry = - new SQLConfEntry[T](key, defaultValue, valueConverter, stringConverter, doc, isPublic) - sqlConfEntries.put(key, entry) - entry - } - - def intConf( - key: String, - defaultValue: Option[Int] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Int] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toInt - } catch { - case _: NumberFormatException => - throw new IllegalArgumentException(s"$key should be int, but was $v") - } - }, _.toString, doc, isPublic) - - def longConf( - key: String, - defaultValue: Option[Long] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Long] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toLong - } catch { - case _: NumberFormatException => - throw new IllegalArgumentException(s"$key should be long, but was $v") - } - }, _.toString, doc, isPublic) - - def longMemConf( - key: String, - defaultValue: Option[Long] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Long] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toLong - } catch { - case _: NumberFormatException => - try { - Utils.byteStringAsBytes(v) - } catch { - case _: NumberFormatException => - throw new IllegalArgumentException(s"$key should be long, but was $v") - } - } - }, _.toString, doc, isPublic) - - def doubleConf( - key: String, - defaultValue: Option[Double] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Double] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toDouble - } catch { - case _: NumberFormatException => - throw new IllegalArgumentException(s"$key should be double, but was $v") - } - }, _.toString, doc, isPublic) - - def booleanConf( - key: String, - defaultValue: Option[Boolean] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Boolean] = - SQLConfEntry(key, defaultValue, { v => - try { - v.toBoolean - } catch { - case _: IllegalArgumentException => - throw new IllegalArgumentException(s"$key should be boolean, but was $v") - } - }, _.toString, doc, isPublic) - - def stringConf( - key: String, - defaultValue: Option[String] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[String] = - SQLConfEntry(key, defaultValue, v => v, v => v, doc, isPublic) - - def enumConf[T]( - key: String, - valueConverter: String => T, - validValues: Set[T], - defaultValue: Option[T] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[T] = - SQLConfEntry(key, defaultValue, v => { - val _v = valueConverter(v) - if (!validValues.contains(_v)) { - throw new IllegalArgumentException( - s"The value of $key should be one of ${validValues.mkString(", ")}, but was $v") - } - _v - }, _.toString, doc, isPublic) - - def seqConf[T]( - key: String, - valueConverter: String => T, - defaultValue: Option[Seq[T]] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Seq[T]] = { - SQLConfEntry( - key, defaultValue, _.split(",").map(valueConverter), _.mkString(","), doc, isPublic) - } + private[sql] object SQLConfigBuilder { - def stringSeqConf( - key: String, - defaultValue: 
Option[Seq[String]] = None, - doc: String = "", - isPublic: Boolean = true): SQLConfEntry[Seq[String]] = { - seqConf(key, s => s, defaultValue, doc, isPublic) - } - } + def apply(key: String): ConfigBuilder = new ConfigBuilder(key).onCreate(register) - import SQLConfEntry._ + } - val ALLOW_MULTIPLE_CONTEXTS = booleanConf("spark.sql.allowMultipleContexts", - defaultValue = Some(true), - doc = "When set to true, creating multiple SQLContexts/HiveContexts is allowed. " + + val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts") + .doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. " + "When set to false, only one SQLContext/HiveContext is allowed to be created " + "through the constructor (new SQLContexts/HiveContexts created through newSession " + "method is allowed). Please note that this conf needs to be set in Spark Conf. Once " + "a SQLContext/HiveContext has been created, changing the value of this conf will not " + - "have effect.", - isPublic = true) - - val COMPRESS_CACHED = booleanConf("spark.sql.inMemoryColumnarStorage.compressed", - defaultValue = Some(true), - doc = "When set to true Spark SQL will automatically select a compression codec for each " + - "column based on statistics of the data.", - isPublic = false) - - val COLUMN_BATCH_SIZE = intConf("spark.sql.inMemoryColumnarStorage.batchSize", - defaultValue = Some(10000), - doc = "Controls the size of batches for columnar caching. Larger batch sizes can improve " + - "memory utilization and compression, but risk OOMs when caching data.", - isPublic = false) + "have effect.") + .booleanConf + .createWithDefault(true) + + val COMPRESS_CACHED = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.compressed") + .internal() + .doc("When set to true Spark SQL will automatically select a compression codec for each " + + "column based on statistics of the data.") + .booleanConf + .createWithDefault(true) + + val COLUMN_BATCH_SIZE = SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.batchSize") + .internal() + .doc("Controls the size of batches for columnar caching. Larger batch sizes can improve " + + "memory utilization and compression, but risk OOMs when caching data.") + .intConf + .createWithDefault(10000) val IN_MEMORY_PARTITION_PRUNING = - booleanConf("spark.sql.inMemoryColumnarStorage.partitionPruning", - defaultValue = Some(true), - doc = "When true, enable partition pruning for in-memory columnar tables.", - isPublic = false) - - val PREFER_SORTMERGEJOIN = booleanConf("spark.sql.join.preferSortMergeJoin", - defaultValue = Some(true), - doc = "When true, prefer sort merge join over shuffle hash join.", - isPublic = false) - - val AUTO_BROADCASTJOIN_THRESHOLD = intConf("spark.sql.autoBroadcastJoinThreshold", - defaultValue = Some(10 * 1024 * 1024), - doc = "Configures the maximum size in bytes for a table that will be broadcast to all worker " + + SQLConfigBuilder("spark.sql.inMemoryColumnarStorage.partitionPruning") + .internal() + .doc("When true, enable partition pruning for in-memory columnar tables.") + .booleanConf + .createWithDefault(true) + + val PREFER_SORTMERGEJOIN = SQLConfigBuilder("spark.sql.join.preferSortMergeJoin") + .internal() + .doc("When true, prefer sort merge join over shuffle hash join.") + .booleanConf + .createWithDefault(true) + + val AUTO_BROADCASTJOIN_THRESHOLD = SQLConfigBuilder("spark.sql.autoBroadcastJoinThreshold") + .doc("Configures the maximum size in bytes for a table that will be broadcast to all worker " + "nodes when performing a join. 
By setting this value to -1 broadcasting can be disabled. " + "Note that currently statistics are only supported for Hive Metastore tables where the " + "commandANALYZE TABLE <tableName> COMPUTE STATISTICS noscan has been run.") + .intConf + .createWithDefault(10 * 1024 * 1024) - val DEFAULT_SIZE_IN_BYTES = longConf( - "spark.sql.defaultSizeInBytes", - doc = "The default table size used in query planning. By default, it is set to a larger " + + val DEFAULT_SIZE_IN_BYTES = SQLConfigBuilder("spark.sql.defaultSizeInBytes") + .internal() + .doc("The default table size used in query planning. By default, it is set to a larger " + "value than `spark.sql.autoBroadcastJoinThreshold` to be more conservative. That is to say " + "by default the optimizer will not choose to broadcast a table unless it knows for sure " + - "its size is small enough.", - isPublic = false) + "its size is small enough.") + .longConf + .createWithDefault(-1) - val SHUFFLE_PARTITIONS = intConf("spark.sql.shuffle.partitions", - defaultValue = Some(200), - doc = "The default number of partitions to use when shuffling data for joins or aggregations.") + val SHUFFLE_PARTITIONS = SQLConfigBuilder("spark.sql.shuffle.partitions") + .doc("The default number of partitions to use when shuffling data for joins or aggregations.") + .intConf + .createWithDefault(200) val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = - longMemConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", - defaultValue = Some(64 * 1024 * 1024), - doc = "The target post-shuffle input size in bytes of a task.") + SQLConfigBuilder("spark.sql.adaptive.shuffle.targetPostShuffleInputSize") + .doc("The target post-shuffle input size in bytes of a task.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(64 * 1024 * 1024) - val ADAPTIVE_EXECUTION_ENABLED = booleanConf("spark.sql.adaptive.enabled", - defaultValue = Some(false), - doc = "When true, enable adaptive query execution.") + val ADAPTIVE_EXECUTION_ENABLED = SQLConfigBuilder("spark.sql.adaptive.enabled") + .doc("When true, enable adaptive query execution.") + .booleanConf + .createWithDefault(false) val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = - intConf("spark.sql.adaptive.minNumPostShufflePartitions", - defaultValue = Some(-1), - doc = "The advisory minimal number of post-shuffle partitions provided to " + + SQLConfigBuilder("spark.sql.adaptive.minNumPostShufflePartitions") + .internal() + .doc("The advisory minimal number of post-shuffle partitions provided to " + "ExchangeCoordinator. This setting is used in our test to make sure we " + "have enough parallelism to expose issues that will not be exposed with a " + "single partition. 
When the value is a non-positive value, this setting will " + - "not be provided to ExchangeCoordinator.", - isPublic = false) - - val SUBEXPRESSION_ELIMINATION_ENABLED = booleanConf("spark.sql.subexpressionElimination.enabled", - defaultValue = Some(true), - doc = "When true, common subexpressions will be eliminated.", - isPublic = false) - - val CASE_SENSITIVE = booleanConf("spark.sql.caseSensitive", - defaultValue = Some(true), - doc = "Whether the query analyzer should be case sensitive or not.") - - val USE_FILE_SCAN = booleanConf("spark.sql.sources.fileScan", - defaultValue = Some(true), - doc = "Use the new FileScanRDD path for reading HDSF based data sources.", - isPublic = false) - - val PARQUET_SCHEMA_MERGING_ENABLED = booleanConf("spark.sql.parquet.mergeSchema", - defaultValue = Some(false), - doc = "When true, the Parquet data source merges schemas collected from all data files, " + - "otherwise the schema is picked from the summary file or a random data file " + - "if no summary file is available.") - - val PARQUET_SCHEMA_RESPECT_SUMMARIES = booleanConf("spark.sql.parquet.respectSummaryFiles", - defaultValue = Some(false), - doc = "When true, we make assumption that all part-files of Parquet are consistent with " + - "summary files and we will ignore them when merging schema. Otherwise, if this is " + - "false, which is the default, we will merge all part-files. This should be considered " + - "as expert-only option, and shouldn't be enabled before knowing what it means exactly.") - - val PARQUET_BINARY_AS_STRING = booleanConf("spark.sql.parquet.binaryAsString", - defaultValue = Some(false), - doc = "Some other Parquet-producing systems, in particular Impala and older versions of " + + "not be provided to ExchangeCoordinator.") + .intConf + .createWithDefault(-1) + + val SUBEXPRESSION_ELIMINATION_ENABLED = + SQLConfigBuilder("spark.sql.subexpressionElimination.enabled") + .internal() + .doc("When true, common subexpressions will be eliminated.") + .booleanConf + .createWithDefault(true) + + val CASE_SENSITIVE = SQLConfigBuilder("spark.sql.caseSensitive") + .doc("Whether the query analyzer should be case sensitive or not.") + .booleanConf + .createWithDefault(true) + + val USE_FILE_SCAN = SQLConfigBuilder("spark.sql.sources.fileScan") + .internal() + .doc("Use the new FileScanRDD path for reading HDSF based data sources.") + .booleanConf + .createWithDefault(true) + + val PARQUET_SCHEMA_MERGING_ENABLED = SQLConfigBuilder("spark.sql.parquet.mergeSchema") + .doc("When true, the Parquet data source merges schemas collected from all data files, " + + "otherwise the schema is picked from the summary file or a random data file " + + "if no summary file is available.") + .booleanConf + .createWithDefault(false) + + val PARQUET_SCHEMA_RESPECT_SUMMARIES = SQLConfigBuilder("spark.sql.parquet.respectSummaryFiles") + .doc("When true, we make assumption that all part-files of Parquet are consistent with " + + "summary files and we will ignore them when merging schema. Otherwise, if this is " + + "false, which is the default, we will merge all part-files. 
This should be considered " + + "as expert-only option, and shouldn't be enabled before knowing what it means exactly.") + .booleanConf + .createWithDefault(false) + + val PARQUET_BINARY_AS_STRING = SQLConfigBuilder("spark.sql.parquet.binaryAsString") + .doc("Some other Parquet-producing systems, in particular Impala and older versions of " + "Spark SQL, do not differentiate between binary data and strings when writing out the " + "Parquet schema. This flag tells Spark SQL to interpret binary data as a string to provide " + "compatibility with these systems.") + .booleanConf + .createWithDefault(false) - val PARQUET_INT96_AS_TIMESTAMP = booleanConf("spark.sql.parquet.int96AsTimestamp", - defaultValue = Some(true), - doc = "Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " + + val PARQUET_INT96_AS_TIMESTAMP = SQLConfigBuilder("spark.sql.parquet.int96AsTimestamp") + .doc("Some Parquet-producing systems, in particular Impala, store Timestamp into INT96. " + "Spark would also store Timestamp as INT96 because we need to avoid precision lost of the " + "nanoseconds field. This flag tells Spark SQL to interpret INT96 data as a timestamp to " + "provide compatibility with these systems.") + .booleanConf + .createWithDefault(true) - val PARQUET_CACHE_METADATA = booleanConf("spark.sql.parquet.cacheMetadata", - defaultValue = Some(true), - doc = "Turns on caching of Parquet schema metadata. Can speed up querying of static data.") + val PARQUET_CACHE_METADATA = SQLConfigBuilder("spark.sql.parquet.cacheMetadata") + .doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.") + .booleanConf + .createWithDefault(true) - val PARQUET_COMPRESSION = enumConf("spark.sql.parquet.compression.codec", - valueConverter = v => v.toLowerCase, - validValues = Set("uncompressed", "snappy", "gzip", "lzo"), - defaultValue = Some("gzip"), - doc = "Sets the compression codec use when writing Parquet files. Acceptable values include: " + + val PARQUET_COMPRESSION = SQLConfigBuilder("spark.sql.parquet.compression.codec") + .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + "uncompressed, snappy, gzip, lzo.") - - val PARQUET_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.parquet.filterPushdown", - defaultValue = Some(true), - doc = "Enables Parquet filter push-down optimization when set to true.") - - val PARQUET_WRITE_LEGACY_FORMAT = booleanConf( - key = "spark.sql.parquet.writeLegacyFormat", - defaultValue = Some(false), - doc = "Whether to follow Parquet's format specification when converting Parquet schema to " + + .stringConf + .transform(_.toLowerCase()) + .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .createWithDefault("gzip") + + val PARQUET_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.parquet.filterPushdown") + .doc("Enables Parquet filter push-down optimization when set to true.") + .booleanConf + .createWithDefault(true) + + val PARQUET_WRITE_LEGACY_FORMAT = SQLConfigBuilder("spark.sql.parquet.writeLegacyFormat") + .doc("Whether to follow Parquet's format specification when converting Parquet schema to " + "Spark SQL schema and vice versa.") + .booleanConf + .createWithDefault(false) - val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( - key = "spark.sql.parquet.output.committer.class", - defaultValue = Some(classOf[ParquetOutputCommitter].getName), - doc = "The output committer class used by Parquet. 
The specified class needs to be a " + + val PARQUET_OUTPUT_COMMITTER_CLASS = SQLConfigBuilder("spark.sql.parquet.output.committer.class") + .doc("The output committer class used by Parquet. The specified class needs to be a " + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + "of org.apache.parquet.hadoop.ParquetOutputCommitter. NOTE: 1. Instead of SQLConf, this " + "option must be set in Hadoop Configuration. 2. This option overrides " + "\"spark.sql.sources.outputCommitterClass\".") - - val PARQUET_VECTORIZED_READER_ENABLED = booleanConf( - key = "spark.sql.parquet.enableVectorizedReader", - defaultValue = Some(true), - doc = "Enables vectorized parquet decoding.") - - val ORC_FILTER_PUSHDOWN_ENABLED = booleanConf("spark.sql.orc.filterPushdown", - defaultValue = Some(false), - doc = "When true, enable filter pushdown for ORC files.") - - val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", - defaultValue = Some(false), - doc = "When true, check all the partition paths under the table\'s root directory " + - "when reading data stored in HDFS.") - - val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", - defaultValue = Some(false), - doc = "When true, some predicates will be pushed down into the Hive metastore so that " + - "unmatching partitions can be eliminated earlier.") - - val NATIVE_VIEW = booleanConf("spark.sql.nativeView", - defaultValue = Some(true), - doc = "When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " + - "Note that this function is experimental and should ony be used when you are using " + - "non-hive-compatible tables written by Spark SQL. The SQL string used to create " + - "view should be fully qualified, i.e. use `tbl1`.`col1` instead of `*` whenever " + - "possible, or you may get wrong result.", - isPublic = false) - - val CANONICAL_NATIVE_VIEW = booleanConf("spark.sql.nativeView.canonical", - defaultValue = Some(true), - doc = "When this option and spark.sql.nativeView are both true, Spark SQL tries to handle " + - "CREATE VIEW statement using SQL query string generated from view definition logical " + - "plan. 
If the logical plan doesn't have a SQL representation, we fallback to the " + - "original native view implementation.", - isPublic = false) - - val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", - defaultValue = Some("_corrupt_record"), - doc = "The name of internal column for storing raw/un-parsed JSON records that fail to parse.") - - val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", - defaultValue = Some(5 * 60), - doc = "Timeout in seconds for the broadcast wait time in broadcast joins.") + .stringConf + .createWithDefault(classOf[ParquetOutputCommitter].getName) + + val PARQUET_VECTORIZED_READER_ENABLED = + SQLConfigBuilder("spark.sql.parquet.enableVectorizedReader") + .doc("Enables vectorized parquet decoding.") + .booleanConf + .createWithDefault(true) + + val ORC_FILTER_PUSHDOWN_ENABLED = SQLConfigBuilder("spark.sql.orc.filterPushdown") + .doc("When true, enable filter pushdown for ORC files.") + .booleanConf + .createWithDefault(false) + + val HIVE_VERIFY_PARTITION_PATH = SQLConfigBuilder("spark.sql.hive.verifyPartitionPath") + .doc("When true, check all the partition paths under the table\'s root directory " + + "when reading data stored in HDFS.") + .booleanConf + .createWithDefault(false) + + val HIVE_METASTORE_PARTITION_PRUNING = + SQLConfigBuilder("spark.sql.hive.metastorePartitionPruning") + .doc("When true, some predicates will be pushed down into the Hive metastore so that " + + "unmatching partitions can be eliminated earlier.") + .booleanConf + .createWithDefault(false) + + val NATIVE_VIEW = SQLConfigBuilder("spark.sql.nativeView") + .internal() + .doc("When true, CREATE VIEW will be handled by Spark SQL instead of Hive native commands. " + + "Note that this function is experimental and should ony be used when you are using " + + "non-hive-compatible tables written by Spark SQL. The SQL string used to create " + + "view should be fully qualified, i.e. use `tbl1`.`col1` instead of `*` whenever " + + "possible, or you may get wrong result.") + .booleanConf + .createWithDefault(true) + + val CANONICAL_NATIVE_VIEW = SQLConfigBuilder("spark.sql.nativeView.canonical") + .internal() + .doc("When this option and spark.sql.nativeView are both true, Spark SQL tries to handle " + + "CREATE VIEW statement using SQL query string generated from view definition logical " + + "plan. 
If the logical plan doesn't have a SQL representation, we fallback to the " + + "original native view implementation.") + .booleanConf + .createWithDefault(true) + + val COLUMN_NAME_OF_CORRUPT_RECORD = SQLConfigBuilder("spark.sql.columnNameOfCorruptRecord") + .doc("The name of internal column for storing raw/un-parsed JSON records that fail to parse.") + .stringConf + .createWithDefault("_corrupt_record") + + val BROADCAST_TIMEOUT = SQLConfigBuilder("spark.sql.broadcastTimeout") + .doc("Timeout in seconds for the broadcast wait time in broadcast joins.") + .intConf + .createWithDefault(5 * 60) // This is only used for the thriftserver - val THRIFTSERVER_POOL = stringConf("spark.sql.thriftserver.scheduler.pool", - doc = "Set a Fair Scheduler pool for a JDBC client session.") - - val THRIFTSERVER_UI_STATEMENT_LIMIT = intConf("spark.sql.thriftserver.ui.retainedStatements", - defaultValue = Some(200), - doc = "The number of SQL statements kept in the JDBC/ODBC web UI history.") - - val THRIFTSERVER_UI_SESSION_LIMIT = intConf("spark.sql.thriftserver.ui.retainedSessions", - defaultValue = Some(200), - doc = "The number of SQL client sessions kept in the JDBC/ODBC web UI history.") + val THRIFTSERVER_POOL = SQLConfigBuilder("spark.sql.thriftserver.scheduler.pool") + .doc("Set a Fair Scheduler pool for a JDBC client session.") + .stringConf + .createOptional + + val THRIFTSERVER_UI_STATEMENT_LIMIT = + SQLConfigBuilder("spark.sql.thriftserver.ui.retainedStatements") + .doc("The number of SQL statements kept in the JDBC/ODBC web UI history.") + .intConf + .createWithDefault(200) + + val THRIFTSERVER_UI_SESSION_LIMIT = SQLConfigBuilder("spark.sql.thriftserver.ui.retainedSessions") + .doc("The number of SQL client sessions kept in the JDBC/ODBC web UI history.") + .intConf + .createWithDefault(200) // This is used to set the default data source - val DEFAULT_DATA_SOURCE_NAME = stringConf("spark.sql.sources.default", - defaultValue = Some("org.apache.spark.sql.parquet"), - doc = "The default data source to use in input/output.") + val DEFAULT_DATA_SOURCE_NAME = SQLConfigBuilder("spark.sql.sources.default") + .doc("The default data source to use in input/output.") + .stringConf + .createWithDefault("org.apache.spark.sql.parquet") // This is used to control the when we will split a schema's JSON string to multiple pieces // in order to fit the JSON string in metastore's table property (by default, the value has // a length restriction of 4000 characters). We will split the JSON string of a schema // to its length exceeds the threshold. 
- val SCHEMA_STRING_LENGTH_THRESHOLD = intConf("spark.sql.sources.schemaStringLengthThreshold", - defaultValue = Some(4000), - doc = "The maximum length allowed in a single cell when " + - "storing additional schema information in Hive's metastore.", - isPublic = false) - - val PARTITION_DISCOVERY_ENABLED = booleanConf("spark.sql.sources.partitionDiscovery.enabled", - defaultValue = Some(true), - doc = "When true, automatically discover data partitions.") + val SCHEMA_STRING_LENGTH_THRESHOLD = + SQLConfigBuilder("spark.sql.sources.schemaStringLengthThreshold") + .doc("The maximum length allowed in a single cell when " + + "storing additional schema information in Hive's metastore.") + .internal() + .intConf + .createWithDefault(4000) + + val PARTITION_DISCOVERY_ENABLED = SQLConfigBuilder("spark.sql.sources.partitionDiscovery.enabled") + .doc("When true, automatically discover data partitions.") + .booleanConf + .createWithDefault(true) val PARTITION_COLUMN_TYPE_INFERENCE = - booleanConf("spark.sql.sources.partitionColumnTypeInference.enabled", - defaultValue = Some(true), - doc = "When true, automatically infer the data types for partitioned columns.") + SQLConfigBuilder("spark.sql.sources.partitionColumnTypeInference.enabled") + .doc("When true, automatically infer the data types for partitioned columns.") + .booleanConf + .createWithDefault(true) val PARTITION_MAX_FILES = - intConf("spark.sql.sources.maxConcurrentWrites", - defaultValue = Some(1), - doc = "The maximum number of concurrent files to open before falling back on sorting when " + + SQLConfigBuilder("spark.sql.sources.maxConcurrentWrites") + .doc("The maximum number of concurrent files to open before falling back on sorting when " + "writing out files using dynamic partitioning.") - - val BUCKETING_ENABLED = booleanConf("spark.sql.sources.bucketing.enabled", - defaultValue = Some(true), - doc = "When false, we will treat bucketed table as normal table.") - - val ORDER_BY_ORDINAL = booleanConf("spark.sql.orderByOrdinal", - defaultValue = Some(true), - doc = "When true, the ordinal numbers are treated as the position in the select list. " + - "When false, the ordinal numbers in order/sort By clause are ignored.") - - val GROUP_BY_ORDINAL = booleanConf("spark.sql.groupByOrdinal", - defaultValue = Some(true), - doc = "When true, the ordinal numbers in group by clauses are treated as the position " + + .intConf + .createWithDefault(1) + + val BUCKETING_ENABLED = SQLConfigBuilder("spark.sql.sources.bucketing.enabled") + .doc("When false, we will treat bucketed table as normal table") + .booleanConf + .createWithDefault(true) + + val ORDER_BY_ORDINAL = SQLConfigBuilder("spark.sql.orderByOrdinal") + .doc("When true, the ordinal numbers are treated as the position in the select list. " + + "When false, the ordinal numbers in order/sort By clause are ignored.") + .booleanConf + .createWithDefault(true) + + val GROUP_BY_ORDINAL = SQLConfigBuilder("spark.sql.groupByOrdinal") + .doc("When true, the ordinal numbers in group by clauses are treated as the position " + "in the select list. When false, the ordinal numbers are ignored.") + .booleanConf + .createWithDefault(true) // The output committer class used by HadoopFsRelation. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. @@ -457,89 +346,95 @@ object SQLConf { // 1. Instead of SQLConf, this option *must be set in Hadoop Configuration*. // 2. This option can be overridden by "spark.sql.parquet.output.committer.class". 
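// Illustrative sketch only, not part of this patch: per the notes above, this option is read
// from the Hadoop Configuration rather than from SQLConf, so a caller would set it roughly as
// follows (the committer class here is only an example of a valid OutputCommitter subclass):
//
//   sparkContext.hadoopConfiguration.set(
//     "spark.sql.sources.outputCommitterClass",
//     classOf[org.apache.parquet.hadoop.ParquetOutputCommitter].getName)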
val OUTPUT_COMMITTER_CLASS = - stringConf("spark.sql.sources.outputCommitterClass", isPublic = false) + SQLConfigBuilder("spark.sql.sources.outputCommitterClass").internal().stringConf.createOptional - val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = intConf( - key = "spark.sql.sources.parallelPartitionDiscovery.threshold", - defaultValue = Some(32), - doc = "The degree of parallelism for schema merging and partition discovery of " + - "Parquet data sources.") + val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = + SQLConfigBuilder("spark.sql.sources.parallelPartitionDiscovery.threshold") + .doc("The degree of parallelism for schema merging and partition discovery of " + + "Parquet data sources.") + .intConf + .createWithDefault(32) // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. - val DATAFRAME_EAGER_ANALYSIS = booleanConf( - "spark.sql.eagerAnalysis", - defaultValue = Some(true), - doc = "When true, eagerly applies query analysis on DataFrame operations.", - isPublic = false) + val DATAFRAME_EAGER_ANALYSIS = SQLConfigBuilder("spark.sql.eagerAnalysis") + .internal() + .doc("When true, eagerly applies query analysis on DataFrame operations.") + .booleanConf + .createWithDefault(true) // Whether to automatically resolve ambiguity in join conditions for self-joins. // See SPARK-6231. - val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = booleanConf( - "spark.sql.selfJoinAutoResolveAmbiguity", - defaultValue = Some(true), - isPublic = false) + val DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY = + SQLConfigBuilder("spark.sql.selfJoinAutoResolveAmbiguity") + .internal() + .booleanConf + .createWithDefault(true) // Whether to retain group by columns or not in GroupedData.agg. - val DATAFRAME_RETAIN_GROUP_COLUMNS = booleanConf( - "spark.sql.retainGroupColumns", - defaultValue = Some(true), - isPublic = false) - - val DATAFRAME_PIVOT_MAX_VALUES = intConf( - "spark.sql.pivotMaxValues", - defaultValue = Some(10000), - doc = "When doing a pivot without specifying values for the pivot column this is the maximum " + - "number of (distinct) values that will be collected without error." - ) - - val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles", - defaultValue = Some(true), - isPublic = false, - doc = "When true, we could use `datasource`.`path` as table in SQL query." 
- ) - - val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage", - defaultValue = Some(true), - doc = "When true, the whole stage (of multiple operators) will be compiled into single java" + - " method.", - isPublic = false) - - val FILES_MAX_PARTITION_BYTES = longConf("spark.sql.files.maxPartitionBytes", - defaultValue = Some(128 * 1024 * 1024), // parquet.block.size - doc = "The maximum number of bytes to pack into a single partition when reading files.", - isPublic = true) - - val FILES_OPEN_COST_IN_BYTES = longConf("spark.sql.files.openCostInBytes", - defaultValue = Some(4 * 1024 * 1024), - doc = "The estimated cost to open a file, measured by the number of bytes could be scanned in" + + val DATAFRAME_RETAIN_GROUP_COLUMNS = SQLConfigBuilder("spark.sql.retainGroupColumns") + .internal() + .booleanConf + .createWithDefault(true) + + val DATAFRAME_PIVOT_MAX_VALUES = SQLConfigBuilder("spark.sql.pivotMaxValues") + .doc("When doing a pivot without specifying values for the pivot column this is the maximum " + + "number of (distinct) values that will be collected without error.") + .intConf + .createWithDefault(10000) + + val RUN_SQL_ON_FILES = SQLConfigBuilder("spark.sql.runSQLOnFiles") + .internal() + .doc("When true, we could use `datasource`.`path` as table in SQL query.") + .booleanConf + .createWithDefault(true) + + val WHOLESTAGE_CODEGEN_ENABLED = SQLConfigBuilder("spark.sql.codegen.wholeStage") + .internal() + .doc("When true, the whole stage (of multiple operators) will be compiled into single java" + + " method.") + .booleanConf + .createWithDefault(true) + + val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes") + .doc("The maximum number of bytes to pack into a single partition when reading files.") + .longConf + .createWithDefault(128 * 1024 * 1024) // parquet.block.size + + val FILES_OPEN_COST_IN_BYTES = SQLConfigBuilder("spark.sql.files.openCostInBytes") + .internal() + .doc("The estimated cost to open a file, measured by the number of bytes could be scanned in" + " the same time. This is used when putting multiple files into a partition. 
It's better to" + " over estimated, then the partitions with small files will be faster than partitions with" + - " bigger files (which is scheduled first).", - isPublic = false) - - val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse", - defaultValue = Some(true), - doc = "When true, the planner will try to find out duplicated exchanges and re-use them.", - isPublic = false) - - val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = intConf( - "spark.sql.streaming.stateStore.minDeltasForSnapshot", - defaultValue = Some(10), - doc = "Minimum number of state store delta files that needs to be generated before they " + - "consolidated into snapshots.", - isPublic = false) - - val STATE_STORE_MIN_VERSIONS_TO_RETAIN = intConf( - "spark.sql.streaming.stateStore.minBatchesToRetain", - defaultValue = Some(2), - doc = "Minimum number of versions of a state store's data to retain after cleaning.", - isPublic = false) - - val CHECKPOINT_LOCATION = stringConf("spark.sql.streaming.checkpointLocation", - defaultValue = None, - doc = "The default location for storing checkpoint data for continuously executing queries.", - isPublic = true) + " bigger files (which is scheduled first).") + .longConf + .createWithDefault(4 * 1024 * 1024) + + val EXCHANGE_REUSE_ENABLED = SQLConfigBuilder("spark.sql.exchange.reuse") + .internal() + .doc("When true, the planner will try to find out duplicated exchanges and re-use them.") + .booleanConf + .createWithDefault(true) + + val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT = + SQLConfigBuilder("spark.sql.streaming.stateStore.minDeltasForSnapshot") + .internal() + .doc("Minimum number of state store delta files that needs to be generated before they " + + "consolidated into snapshots.") + .intConf + .createWithDefault(10) + + val STATE_STORE_MIN_VERSIONS_TO_RETAIN = + SQLConfigBuilder("spark.sql.streaming.stateStore.minBatchesToRetain") + .internal() + .doc("Minimum number of versions of a state store's data to retain after cleaning.") + .intConf + .createWithDefault(2) + + val CHECKPOINT_LOCATION = SQLConfigBuilder("spark.sql.streaming.checkpointLocation") + .doc("The default location for storing checkpoint data for continuously executing queries.") + .stringConf + .createOptional object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -562,7 +457,7 @@ object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -class SQLConf extends Serializable with CatalystConf with Logging { +private[sql] class SQLConf extends Serializable with CatalystConf with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -686,7 +581,7 @@ class SQLConf extends Serializable with CatalystConf with Logging { } /** Set the given Spark SQL configuration property. */ - def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + def setConf[T](entry: ConfigEntry[T], value: T): Unit = { require(entry != null, "entry cannot be null") require(value != null, s"value cannot be null for key: ${entry.key}") require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") @@ -706,24 +601,34 @@ class SQLConf extends Serializable with CatalystConf with Logging { /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue`. This is useful when `defaultValue` in SQLConfEntry is not the + * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the * desired one. 
*/ - def getConf[T](entry: SQLConfEntry[T], defaultValue: T): T = { + def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = { require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") Option(settings.get(entry.key)).map(entry.valueConverter).getOrElse(defaultValue) } /** * Return the value of Spark SQL configuration property for the given key. If the key is not set - * yet, return `defaultValue` in [[SQLConfEntry]]. + * yet, return `defaultValue` in [[ConfigEntry]]. */ - def getConf[T](entry: SQLConfEntry[T]): T = { + def getConf[T](entry: ConfigEntry[T]): T = { require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") Option(settings.get(entry.key)).map(entry.valueConverter).orElse(entry.defaultValue). getOrElse(throw new NoSuchElementException(entry.key)) } + /** + * Return the value of an optional Spark SQL configuration property for the given key. If the key + * is not set yet, throw an exception. + */ + def getConf[T](entry: OptionalConfigEntry[T]): T = { + require(sqlConfEntries.get(entry.key) == entry, s"$entry is not registered") + Option(settings.get(entry.key)).map(entry.rawValueConverter). + getOrElse(throw new NoSuchElementException(entry.key)) + } + /** * Return the `string` value of Spark SQL configuration property for the given key. If the key is * not set yet, return `defaultValue`. @@ -765,7 +670,7 @@ class SQLConf extends Serializable with CatalystConf with Logging { settings.remove(key) } - def unsetConf(entry: SQLConfEntry[_]): Unit = { + private[spark] def unsetConf(entry: ConfigEntry[_]): Unit = { settings.remove(entry.key) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala index 2b89fa9f23..cc69199139 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfEntrySuite.scala @@ -26,7 +26,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("intConf") { val key = "spark.sql.SQLConfEntrySuite.int" - val confEntry = SQLConfEntry.intConf(key) + val confEntry = SQLConfigBuilder(key).intConf.createWithDefault(1) assert(conf.getConf(confEntry, 5) === 5) conf.setConf(confEntry, 10) @@ -45,7 +45,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("longConf") { val key = "spark.sql.SQLConfEntrySuite.long" - val confEntry = SQLConfEntry.longConf(key) + val confEntry = SQLConfigBuilder(key).longConf.createWithDefault(1L) assert(conf.getConf(confEntry, 5L) === 5L) conf.setConf(confEntry, 10L) @@ -64,7 +64,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("booleanConf") { val key = "spark.sql.SQLConfEntrySuite.boolean" - val confEntry = SQLConfEntry.booleanConf(key) + val confEntry = SQLConfigBuilder(key).booleanConf.createWithDefault(true) assert(conf.getConf(confEntry, false) === false) conf.setConf(confEntry, true) @@ -83,7 +83,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("doubleConf") { val key = "spark.sql.SQLConfEntrySuite.double" - val confEntry = SQLConfEntry.doubleConf(key) + val confEntry = SQLConfigBuilder(key).doubleConf.createWithDefault(1d) assert(conf.getConf(confEntry, 5.0) === 5.0) conf.setConf(confEntry, 10.0) @@ -102,7 +102,7 @@ class SQLConfEntrySuite extends SparkFunSuite { test("stringConf") { val key = "spark.sql.SQLConfEntrySuite.string" - val confEntry = SQLConfEntry.stringConf(key) + val confEntry = SQLConfigBuilder(key).stringConf.createWithDefault(null) 
assert(conf.getConf(confEntry, "abc") === "abc") conf.setConf(confEntry, "abcd") @@ -116,7 +116,10 @@ class SQLConfEntrySuite extends SparkFunSuite { test("enumConf") { val key = "spark.sql.SQLConfEntrySuite.enum" - val confEntry = SQLConfEntry.enumConf(key, v => v, Set("a", "b", "c"), defaultValue = Some("a")) + val confEntry = SQLConfigBuilder(key) + .stringConf + .checkValues(Set("a", "b", "c")) + .createWithDefault("a") assert(conf.getConf(confEntry) === "a") conf.setConf(confEntry, "b") @@ -135,8 +138,10 @@ class SQLConfEntrySuite extends SparkFunSuite { test("stringSeqConf") { val key = "spark.sql.SQLConfEntrySuite.stringSeq" - val confEntry = SQLConfEntry.stringSeqConf("spark.sql.SQLConfEntrySuite.stringSeq", - defaultValue = Some(Nil)) + val confEntry = SQLConfigBuilder(key) + .stringConf + .toSequence + .createWithDefault(Nil) assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c")) conf.setConf(confEntry, Seq("a", "b", "c", "d")) @@ -147,4 +152,12 @@ class SQLConfEntrySuite extends SparkFunSuite { assert(conf.getConfString(key) === "a,b,c,d,e") assert(conf.getConf(confEntry, Seq("a", "b", "c")) === Seq("a", "b", "c", "d", "e")) } + + test("duplicate entry") { + val key = "spark.sql.SQLConfEntrySuite.duplicate" + SQLConfigBuilder(key).stringConf.createOptional + intercept[IllegalArgumentException] { + SQLConfigBuilder(key).stringConf.createOptional + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index e944d328a3..e687e6a5ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -119,15 +119,10 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } intercept[IllegalArgumentException] { - // This value less than Int.MinValue + // This value less than Long.MinValue sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") } - // Test invalid input - intercept[IllegalArgumentException] { - // This value exceeds Long.MaxValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1g") - } sqlContext.conf.clear() } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 073b954a5f..505e5c0bb6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -42,6 +42,7 @@ import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -54,8 +55,7 @@ import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.SQLConf.SQLConfEntry -import org.apache.spark.sql.internal.SQLConf.SQLConfEntry._ +import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -318,7 +318,7 @@ class HiveContext 
private[hive]( hiveconf.set(key, value) } - override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = { + override private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = { setConf(entry.key, entry.stringConverter(value)) } @@ -413,19 +413,19 @@ private[hive] object HiveContext extends Logging { /** The version of hive used internally by Spark SQL. */ val hiveExecutionVersion: String = "1.2.1" - val HIVE_METASTORE_VERSION = stringConf("spark.sql.hive.metastore.version", - defaultValue = Some(hiveExecutionVersion), - doc = "Version of the Hive metastore. Available options are " + + val HIVE_METASTORE_VERSION = SQLConfigBuilder("spark.sql.hive.metastore.version") + .doc("Version of the Hive metastore. Available options are " + s"0.12.0 through $hiveExecutionVersion.") + .stringConf + .createWithDefault(hiveExecutionVersion) - val HIVE_EXECUTION_VERSION = stringConf( - key = "spark.sql.hive.version", - defaultValue = Some(hiveExecutionVersion), - doc = "Version of Hive used internally by Spark SQL.") + val HIVE_EXECUTION_VERSION = SQLConfigBuilder("spark.sql.hive.version") + .doc("Version of Hive used internally by Spark SQL.") + .stringConf + .createWithDefault(hiveExecutionVersion) - val HIVE_METASTORE_JARS = stringConf("spark.sql.hive.metastore.jars", - defaultValue = Some("builtin"), - doc = s""" + val HIVE_METASTORE_JARS = SQLConfigBuilder("spark.sql.hive.metastore.jars") + .doc(s""" | Location of the jars that should be used to instantiate the HiveMetastoreClient. | This property can be one of three options: " | 1. "builtin" @@ -436,49 +436,61 @@ private[hive] object HiveContext extends Logging { | 2. "maven" | Use Hive jars of specified version downloaded from Maven repositories. | 3. A classpath in the standard format for both Hive and Hadoop. - """.stripMargin) - val CONVERT_METASTORE_PARQUET = booleanConf("spark.sql.hive.convertMetastoreParquet", - defaultValue = Some(true), - doc = "When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + - "the built in support.") - - val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf( - "spark.sql.hive.convertMetastoreParquet.mergeSchema", - defaultValue = Some(false), - doc = "When true, also tries to merge possibly different but compatible Parquet schemas in " + - "different Parquet data files. This configuration is only effective " + - "when \"spark.sql.hive.convertMetastoreParquet\" is true.") + """.stripMargin) + .stringConf + .createWithDefault("builtin") - val CONVERT_METASTORE_ORC = booleanConf("spark.sql.hive.convertMetastoreOrc", - defaultValue = Some(true), - doc = "When set to false, Spark SQL will use the Hive SerDe for ORC tables instead of " + + val CONVERT_METASTORE_PARQUET = SQLConfigBuilder("spark.sql.hive.convertMetastoreParquet") + .doc("When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of " + "the built in support.") - - val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", - defaultValue = Some(false), - doc = "When true, a table created by a Hive CTAS statement (no USING clause) will be " + + .booleanConf + .createWithDefault(true) + + val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = + SQLConfigBuilder("spark.sql.hive.convertMetastoreParquet.mergeSchema") + .doc("When true, also tries to merge possibly different but compatible Parquet schemas in " + + "different Parquet data files. 
This configuration is only effective " + + "when \"spark.sql.hive.convertMetastoreParquet\" is true.") + .booleanConf + .createWithDefault(false) + + val CONVERT_CTAS = SQLConfigBuilder("spark.sql.hive.convertCTAS") + .doc("When true, a table created by a Hive CTAS statement (no USING clause) will be " + "converted to a data source table, using the data source set by spark.sql.sources.default.") + .booleanConf + .createWithDefault(false) - val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes", - defaultValue = Some(jdbcPrefixes), - doc = "A comma separated list of class prefixes that should be loaded using the classloader " + + val CONVERT_METASTORE_ORC = SQLConfigBuilder("spark.sql.hive.convertMetastoreOrc") + .doc("When set to false, Spark SQL will use the Hive SerDe for ORC tables instead of " + + "the built in support.") + .booleanConf + .createWithDefault(true) + + val HIVE_METASTORE_SHARED_PREFIXES = SQLConfigBuilder("spark.sql.hive.metastore.sharedPrefixes") + .doc("A comma separated list of class prefixes that should be loaded using the classloader " + "that is shared between Spark SQL and a specific version of Hive. An example of classes " + "that should be shared is JDBC drivers that are needed to talk to the metastore. Other " + "classes that need to be shared are those that interact with classes that are already " + "shared. For example, custom appenders that are used by log4j.") + .stringConf + .toSequence + .createWithDefault(jdbcPrefixes) private def jdbcPrefixes = Seq( "com.mysql.jdbc", "org.postgresql", "com.microsoft.sqlserver", "oracle.jdbc") - val HIVE_METASTORE_BARRIER_PREFIXES = stringSeqConf("spark.sql.hive.metastore.barrierPrefixes", - defaultValue = Some(Seq()), - doc = "A comma separated list of class prefixes that should explicitly be reloaded for each " + + val HIVE_METASTORE_BARRIER_PREFIXES = SQLConfigBuilder("spark.sql.hive.metastore.barrierPrefixes") + .doc("A comma separated list of class prefixes that should explicitly be reloaded for each " + "version of Hive that Spark SQL is communicating with. For example, Hive UDFs that are " + "declared in a prefix that typically would be shared (i.e. org.apache.spark.*).") - - val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async", - defaultValue = Some(true), - doc = "When set to true, Hive Thrift server executes SQL queries in an asynchronous way.") + .stringConf + .toSequence + .createWithDefault(Nil) + + val HIVE_THRIFT_SERVER_ASYNC = SQLConfigBuilder("spark.sql.hive.thriftServer.async") + .doc("When set to true, Hive Thrift server executes SQL queries in an asynchronous way.") + .booleanConf + .createWithDefault(true) /** * The version of the hive client that will be used to communicate with the metastore. 
Note that diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 5188a3e229..8d576bebb0 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -31,82 +31,82 @@ package object config { "in YARN Application Reports, which can be used for filtering when querying YARN.") .stringConf .toSequence - .optional + .createOptional private[spark] val ATTEMPT_FAILURE_VALIDITY_INTERVAL_MS = ConfigBuilder("spark.yarn.am.attemptFailuresValidityInterval") .doc("Interval after which AM failures will be considered independent and " + "not accumulate towards the attempt count.") .timeConf(TimeUnit.MILLISECONDS) - .optional + .createOptional private[spark] val MAX_APP_ATTEMPTS = ConfigBuilder("spark.yarn.maxAppAttempts") .doc("Maximum number of AM attempts before failing the app.") .intConf - .optional + .createOptional private[spark] val USER_CLASS_PATH_FIRST = ConfigBuilder("spark.yarn.user.classpath.first") .doc("Whether to place user jars in front of Spark's classpath.") .booleanConf - .withDefault(false) + .createWithDefault(false) private[spark] val GATEWAY_ROOT_PATH = ConfigBuilder("spark.yarn.config.gatewayPath") .doc("Root of configuration paths that is present on gateway nodes, and will be replaced " + "with the corresponding path in cluster machines.") .stringConf - .withDefault(null) + .createWithDefault(null) private[spark] val REPLACEMENT_ROOT_PATH = ConfigBuilder("spark.yarn.config.replacementPath") .doc(s"Path to use as a replacement for ${GATEWAY_ROOT_PATH.key} when launching processes " + "in the YARN cluster.") .stringConf - .withDefault(null) + .createWithDefault(null) private[spark] val QUEUE_NAME = ConfigBuilder("spark.yarn.queue") .stringConf - .withDefault("default") + .createWithDefault("default") private[spark] val HISTORY_SERVER_ADDRESS = ConfigBuilder("spark.yarn.historyServer.address") .stringConf - .optional + .createOptional /* File distribution. 
*/ private[spark] val SPARK_ARCHIVE = ConfigBuilder("spark.yarn.archive") .doc("Location of archive containing jars files with Spark classes.") .stringConf - .optional + .createOptional private[spark] val SPARK_JARS = ConfigBuilder("spark.yarn.jars") .doc("Location of jars containing Spark classes.") .stringConf .toSequence - .optional + .createOptional private[spark] val ARCHIVES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.archives") .stringConf .toSequence - .withDefault(Nil) + .createWithDefault(Nil) private[spark] val FILES_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.files") .stringConf .toSequence - .withDefault(Nil) + .createWithDefault(Nil) private[spark] val JARS_TO_DISTRIBUTE = ConfigBuilder("spark.yarn.dist.jars") .stringConf .toSequence - .withDefault(Nil) + .createWithDefault(Nil) private[spark] val PRESERVE_STAGING_FILES = ConfigBuilder("spark.yarn.preserve.staging.files") .doc("Whether to preserve temporary files created by the job in HDFS.") .booleanConf - .withDefault(false) + .createWithDefault(false) private[spark] val STAGING_FILE_REPLICATION = ConfigBuilder("spark.yarn.submit.file.replication") .doc("Replication factor for files uploaded by Spark to HDFS.") .intConf - .optional + .createOptional private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir") .doc("Staging directory used while submitting applications.") @@ -119,146 +119,146 @@ package object config { .doc("In cluster mode, whether to wait for the application to finish before exiting the " + "launcher process.") .booleanConf - .withDefault(true) + .createWithDefault(true) private[spark] val REPORT_INTERVAL = ConfigBuilder("spark.yarn.report.interval") .doc("Interval between reports of the current app status in cluster mode.") .timeConf(TimeUnit.MILLISECONDS) - .withDefaultString("1s") + .createWithDefaultString("1s") /* Shared Client-mode AM / Driver configuration. */ private[spark] val AM_MAX_WAIT_TIME = ConfigBuilder("spark.yarn.am.waitTime") .timeConf(TimeUnit.MILLISECONDS) - .withDefaultString("100s") + .createWithDefaultString("100s") private[spark] val AM_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.am.nodeLabelExpression") .doc("Node label expression for the AM.") .stringConf - .optional + .createOptional private[spark] val CONTAINER_LAUNCH_MAX_THREADS = ConfigBuilder("spark.yarn.containerLauncherMaxThreads") .intConf - .withDefault(25) + .createWithDefault(25) private[spark] val MAX_EXECUTOR_FAILURES = ConfigBuilder("spark.yarn.max.executor.failures") .intConf - .optional + .createOptional private[spark] val MAX_REPORTER_THREAD_FAILURES = ConfigBuilder("spark.yarn.scheduler.reporterThread.maxFailures") .intConf - .withDefault(5) + .createWithDefault(5) private[spark] val RM_HEARTBEAT_INTERVAL = ConfigBuilder("spark.yarn.scheduler.heartbeat.interval-ms") .timeConf(TimeUnit.MILLISECONDS) - .withDefaultString("3s") + .createWithDefaultString("3s") private[spark] val INITIAL_HEARTBEAT_INTERVAL = ConfigBuilder("spark.yarn.scheduler.initial-allocation.interval") .timeConf(TimeUnit.MILLISECONDS) - .withDefaultString("200ms") + .createWithDefaultString("200ms") private[spark] val SCHEDULER_SERVICES = ConfigBuilder("spark.yarn.services") .doc("A comma-separated list of class names of services to add to the scheduler.") .stringConf .toSequence - .withDefault(Nil) + .createWithDefault(Nil) /* Client-mode AM configuration. 
*/ private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores") .intConf - .withDefault(1) + .createWithDefault(1) private[spark] val AM_JAVA_OPTIONS = ConfigBuilder("spark.yarn.am.extraJavaOptions") .doc("Extra Java options for the client-mode AM.") .stringConf - .optional + .createOptional private[spark] val AM_LIBRARY_PATH = ConfigBuilder("spark.yarn.am.extraLibraryPath") .doc("Extra native library path for the client-mode AM.") .stringConf - .optional + .createOptional private[spark] val AM_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.am.memoryOverhead") .bytesConf(ByteUnit.MiB) - .optional + .createOptional private[spark] val AM_MEMORY = ConfigBuilder("spark.yarn.am.memory") .bytesConf(ByteUnit.MiB) - .withDefaultString("512m") + .createWithDefaultString("512m") /* Driver configuration. */ private[spark] val DRIVER_CORES = ConfigBuilder("spark.driver.cores") .intConf - .withDefault(1) + .createWithDefault(1) private[spark] val DRIVER_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.driver.memoryOverhead") .bytesConf(ByteUnit.MiB) - .optional + .createOptional /* Executor configuration. */ private[spark] val EXECUTOR_CORES = ConfigBuilder("spark.executor.cores") .intConf - .withDefault(1) + .createWithDefault(1) private[spark] val EXECUTOR_MEMORY_OVERHEAD = ConfigBuilder("spark.yarn.executor.memoryOverhead") .bytesConf(ByteUnit.MiB) - .optional + .createOptional private[spark] val EXECUTOR_NODE_LABEL_EXPRESSION = ConfigBuilder("spark.yarn.executor.nodeLabelExpression") .doc("Node label expression for executors.") .stringConf - .optional + .createOptional /* Security configuration. */ private[spark] val CREDENTIAL_FILE_MAX_COUNT = ConfigBuilder("spark.yarn.credentials.file.retention.count") .intConf - .withDefault(5) + .createWithDefault(5) private[spark] val CREDENTIALS_FILE_MAX_RETENTION = ConfigBuilder("spark.yarn.credentials.file.retention.days") .intConf - .withDefault(5) + .createWithDefault(5) private[spark] val NAMENODES_TO_ACCESS = ConfigBuilder("spark.yarn.access.namenodes") .doc("Extra NameNode URLs for which to request delegation tokens. The NameNode that hosts " + "fs.defaultFS does not need to be listed here.") .stringConf .toSequence - .withDefault(Nil) + .createWithDefault(Nil) private[spark] val TOKEN_RENEWAL_INTERVAL = ConfigBuilder("spark.yarn.token.renewal.interval") - .internal + .internal() .timeConf(TimeUnit.MILLISECONDS) - .optional + .createOptional /* Private configs. */ private[spark] val CREDENTIALS_FILE_PATH = ConfigBuilder("spark.yarn.credentials.file") - .internal + .internal() .stringConf - .withDefault(null) + .createWithDefault(null) // Internal config to propagate the location of the user's jar to the driver/executors private[spark] val APP_JAR = ConfigBuilder("spark.yarn.user.jar") - .internal + .internal() .stringConf - .optional + .createOptional // Internal config to propagate the locations of any extra jars to add to the classpath // of the executors private[spark] val SECONDARY_JARS = ConfigBuilder("spark.yarn.secondary.jars") - .internal + .internal() .stringConf .toSequence - .optional + .createOptional } -- cgit v1.2.3 From 48682f6bf663e54cb63b7e95a4520d34b6fa890b Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 5 Apr 2016 18:10:40 -0500 Subject: [HOTFIX] Fix `optional` to `createOptional`. ## What changes were proposed in this pull request? This PR fixes the following line. 
``` private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir") .doc("Staging directory used while submitting applications.") .stringConf - .optional + .createOptional ``` ## How was this patch tested? Pass the build. Author: Dongjoon Hyun Closes #12187 from dongjoon-hyun/hotfix. --- yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 8d576bebb0..edfbfc5d58 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -111,7 +111,7 @@ package object config { private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir") .doc("Staging directory used while submitting applications.") .stringConf - .optional + .createOptional /* Cluster-mode launcher configuration. */ -- cgit v1.2.3 From 1146c534d6c3806f3e920043ba06838ef02cd7e8 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 5 Apr 2016 17:21:41 -0700 Subject: [SPARK-14353] Dataset Time Window `window` API for R ## What changes were proposed in this pull request? The `window` function was added to Dataset with [this PR](https://github.com/apache/spark/pull/12008). This PR adds the R API for this function. With this PR, SQL, Java, and Scala will share the same APIs as in users can use: - `window(timeColumn, windowDuration)` - `window(timeColumn, windowDuration, slideDuration)` - `window(timeColumn, windowDuration, slideDuration, startTime)` In Python and R, users can access all APIs above, but in addition they can do - In R: `window(timeColumn, windowDuration, startTime=...)` that is, they can provide the startTime without providing the `slideDuration`. In this case, we will generate tumbling windows. ## How was this patch tested? Unit tests + manual tests Author: Burak Yavuz Closes #12141 from brkyvz/R-windows. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 63 +++++++++++++++++++++++++++++++ R/pkg/R/generics.R | 4 ++ R/pkg/inst/tests/testthat/test_context.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 36 ++++++++++++++++++ 5 files changed, 105 insertions(+), 1 deletion(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index fa3fb0b09a..f48c61c1d5 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -265,6 +265,7 @@ exportMethods("%in%", "var_samp", "weekofyear", "when", + "window", "year") exportClasses("GroupedData") diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index d9c10b4a4b..db877b2d63 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -2131,6 +2131,69 @@ setMethod("from_unixtime", signature(x = "Column"), column(jc) }) +#' window +#' +#' Bucketize rows into one or more time windows given a timestamp specifying column. Window +#' starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window +#' [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in +#' the order of months are not supported. +#' +#' The time column must be of TimestampType. +#' +#' Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid +#' interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. +#' If the `slideDuration` is not provided, the windows will be tumbling windows. +#' +#' The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start +#' window intervals. 
For example, in order to have hourly tumbling windows that start 15 minutes +#' past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. +#' +#' The output column will be a struct called 'window' by default with the nested columns 'start' +#' and 'end'. +#' +#' @family datetime_funcs +#' @rdname window +#' @name window +#' @export +#' @examples +#'\dontrun{ +#' # One minute windows every 15 seconds 10 seconds after the minute, e.g. 09:00:10-09:01:10, +#' # 09:00:25-09:01:25, 09:00:40-09:01:40, ... +#' window(df$time, "1 minute", "15 seconds", "10 seconds") +#' +#' # One minute tumbling windows 15 seconds after the minute, e.g. 09:00:15-09:01:15, +#' # 09:01:15-09:02:15... +#' window(df$time, "1 minute", startTime = "15 seconds") +#' +#' # Thirty second windows every 10 seconds, e.g. 09:00:00-09:00:30, 09:00:10-09:00:40, ... +#' window(df$time, "30 seconds", "10 seconds") +#'} +setMethod("window", signature(x = "Column"), + function(x, windowDuration, slideDuration = NULL, startTime = NULL) { + stopifnot(is.character(windowDuration)) + if (!is.null(slideDuration) && !is.null(startTime)) { + stopifnot(is.character(slideDuration) && is.character(startTime)) + jc <- callJStatic("org.apache.spark.sql.functions", + "window", + x@jc, windowDuration, slideDuration, startTime) + } else if (!is.null(slideDuration)) { + stopifnot(is.character(slideDuration)) + jc <- callJStatic("org.apache.spark.sql.functions", + "window", + x@jc, windowDuration, slideDuration) + } else if (!is.null(startTime)) { + stopifnot(is.character(startTime)) + jc <- callJStatic("org.apache.spark.sql.functions", + "window", + x@jc, windowDuration, windowDuration, startTime) + } else { + jc <- callJStatic("org.apache.spark.sql.functions", + "window", + x@jc, windowDuration) + } + column(jc) + }) + #' locate #' #' Locate the position of the first occurrence of substr. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index c6990f4748..ecdeea5ec4 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1152,6 +1152,10 @@ setGeneric("var_samp", function(x) { standardGeneric("var_samp") }) #' @export setGeneric("weekofyear", function(x) { standardGeneric("weekofyear") }) +#' @rdname window +#' @export +setGeneric("window", function(x, ...) 
{ standardGeneric("window") }) + #' @rdname year #' @export setGeneric("year", function(x) { standardGeneric("year") }) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index ad3f9722a4..6e06c974c2 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -26,7 +26,7 @@ test_that("Check masked functions", { maskedBySparkR <- masked[funcSparkROrEmpty] namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform", "drop") + "summary", "transform", "drop", "window") expect_equal(length(maskedBySparkR), length(namesOfMasked)) expect_equal(sort(maskedBySparkR), sort(namesOfMasked)) # above are those reported as masked when `library(SparkR)` diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index eef365b42e..22eb3ec984 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1204,6 +1204,42 @@ test_that("greatest() and least() on a DataFrame", { expect_equal(collect(select(df, least(df$a, df$b)))[, 1], c(1, 3)) }) +test_that("time windowing (window()) with all inputs", { + df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df$window <- window(df$t, "5 seconds", "5 seconds", "0 seconds") + local <- collect(df)$v + # Not checking time windows because of possible time zone issues. Just checking that the function + # works + expect_equal(local, c(1)) +}) + +test_that("time windowing (window()) with slide duration", { + df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df$window <- window(df$t, "5 seconds", "2 seconds") + local <- collect(df)$v + # Not checking time windows because of possible time zone issues. Just checking that the function + # works + expect_equal(local, c(1, 1)) +}) + +test_that("time windowing (window()) with start time", { + df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df$window <- window(df$t, "5 seconds", startTime = "2 seconds") + local <- collect(df)$v + # Not checking time windows because of possible time zone issues. Just checking that the function + # works + expect_equal(local, c(1)) +}) + +test_that("time windowing (window()) with just window duration", { + df <- createDataFrame(sqlContext, data.frame(t = c("2016-03-11 09:00:07"), v = c(1))) + df$window <- window(df$t, "5 seconds") + local <- collect(df)$v + # Not checking time windows because of possible time zone issues. Just checking that the function + # works + expect_equal(local, c(1)) +}) + test_that("when(), otherwise() and ifelse() on a DataFrame", { l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) df <- createDataFrame(sqlContext, l) -- cgit v1.2.3 From 7d29c72f64f8637d8182fb7c495f87ab7ce86ea0 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 5 Apr 2016 21:22:20 -0500 Subject: [SPARK-14359] Unit tests for java 8 lambda syntax with typed aggregates ## What changes were proposed in this pull request? Adds unit tests for java 8 lambda syntax with typed aggregates as a follow-up to #12168 ## How was this patch tested? Unit tests. Author: Eric Liang Closes #12181 from ericl/sc-2794-2. 
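As a minimal illustration of the lambda syntax these tests exercise (a hypothetical sketch, not code from this patch): assuming a `KeyValueGroupedDataset` built the same way as `generateGroupedDataset()` in the shared base class below, a typed aggregation can be written with a lambda in place of an anonymous `MapFunction`:

```java
import scala.Tuple2;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.expressions.java.typed;

// Hypothetical helper, not part of the patch; `grouped` is assumed to hold the
// ("a",1), ("a",2), ("b",3) tuples keyed by their first field, as in the suite below.
class TypedAggregateLambdaSketch {
  static void sumWithLambda(KeyValueGroupedDataset<String, Tuple2<String, Integer>> grouped) {
    // The lambda replaces the anonymous MapFunction classes required under Java 7.
    Dataset<Tuple2<String, Double>> agged = grouped.agg(typed.sum(v -> (double) v._2()));
    agged.collectAsList(); // expected: [(a,3.0), (b,3.0)]
  }
}
```

The Java 7 anonymous-class versions of the same aggregations remain in `JavaDatasetAggregatorSuite`, so both syntaxes stay covered.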
--- external/java8-tests/pom.xml | 12 +++ .../spark/sql/Java8DatasetAggregatorSuite.java | 61 +++++++++++++++ .../sql/sources/JavaDatasetAggregatorSuite.java | 86 +++++++++++----------- 3 files changed, 118 insertions(+), 41 deletions(-) create mode 100644 external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java diff --git a/external/java8-tests/pom.xml b/external/java8-tests/pom.xml index f5a06467ee..1ea9196e9d 100644 --- a/external/java8-tests/pom.xml +++ b/external/java8-tests/pom.xml @@ -58,6 +58,18 @@ test-jar test + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-test-tags_${scala.binary.version} diff --git a/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java b/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java new file mode 100644 index 0000000000..23abfa3970 --- /dev/null +++ b/external/java8-tests/src/test/java/org/apache/spark/sql/Java8DatasetAggregatorSuite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.util.Arrays; + +import org.junit.Assert; +import org.junit.Test; +import scala.Tuple2; + +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.expressions.java.typed; + +/** + * Suite that replicates tests in JavaDatasetAggregatorSuite using lambda syntax. 
+ */ +public class Java8DatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { + @Test + public void testTypedAggregationAverage() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.avg(v -> (double)(v._2() * 2))); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 6.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationCount() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.count(v -> v)); + Assert.assertEquals(Arrays.asList(tuple2("a", 2), tuple2("b", 1)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumDouble() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sum(v -> (double)v._2())); + Assert.assertEquals(Arrays.asList(tuple2("a", 3.0), tuple2("b", 3.0)), agged.collectAsList()); + } + + @Test + public void testTypedAggregationSumLong() { + KeyValueGroupedDataset> grouped = generateGroupedDataset(); + Dataset> agged = grouped.agg(typed.sumLong(v -> (long)v._2())); + Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); + } +} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index c8d0eecd5c..594f4675bd 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -41,46 +41,7 @@ import org.apache.spark.sql.test.TestSQLContext; /** * Suite for testing the aggregate functionality of Datasets in Java. */ -public class JavaDatasetAggregatorSuite implements Serializable { - private transient JavaSparkContext jsc; - private transient TestSQLContext context; - - @Before - public void setUp() { - // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); - } - - @After - public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; - } - - private Tuple2 tuple2(T1 t1, T2 t2) { - return new Tuple2<>(t1, t2); - } - - private KeyValueGroupedDataset> generateGroupedDataset() { - Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); - List> data = - Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); - Dataset> ds = context.createDataset(data, encoder); - - return ds.groupByKey( - new MapFunction, String>() { - @Override - public String call(Tuple2 value) throws Exception { - return value._1(); - } - }, - Encoders.STRING()); - } - +public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { @Test public void testTypedAggregationAnonClass() { KeyValueGroupedDataset> grouped = generateGroupedDataset(); @@ -100,7 +61,6 @@ public class JavaDatasetAggregatorSuite implements Serializable { } static class IntSumOf extends Aggregator, Integer, Integer> { - @Override public Integer zero() { return 0; @@ -170,3 +130,47 @@ public class JavaDatasetAggregatorSuite implements Serializable { Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); } } + +/** + * Common test base shared across this and Java8DatasetAggregatorSuite. 
+ */ +class JavaDatasetAggregatorSuiteBase implements Serializable { + protected transient JavaSparkContext jsc; + protected transient TestSQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + protected Tuple2 tuple2(T1 t1, T2 t2) { + return new Tuple2<>(t1, t2); + } + + protected KeyValueGroupedDataset> generateGroupedDataset() { + Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); + List> data = + Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); + Dataset> ds = context.createDataset(data, encoder); + + return ds.groupByKey( + new MapFunction, String>() { + @Override + public String call(Tuple2 value) throws Exception { + return value._1(); + } + }, + Encoders.STRING()); + } +} -- cgit v1.2.3 From 8e5c1cbf2c3d5eaa7d9dd35def177414a0d4cf82 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 5 Apr 2016 19:57:23 -0700 Subject: [SPARK-13211][STREAMING] StreamingContext throws NoSuchElementException when created from non-existent checkpoint directory ## What changes were proposed in this pull request? Take 2: avoid None.get NoSuchElementException in favor of more descriptive IllegalArgumentException if a non-existent checkpoint dir is used without a SparkContext ## How was this patch tested? Jenkins test plus new test for this particular case Author: Sean Owen Closes #12174 from srowen/SPARK-13211. --- .../main/scala/org/apache/spark/streaming/Checkpoint.scala | 3 +-- .../scala/org/apache/spark/streaming/StreamingContext.scala | 11 ++++------- .../scala/org/apache/spark/streaming/CheckpointSuite.scala | 5 +++++ 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index f9f3d97ef3..5cc677d085 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -334,8 +334,7 @@ object CheckpointReader extends Logging { ignoreReadError: Boolean = false): Option[Checkpoint] = { val checkpointPath = new Path(checkpointDir) - // TODO(rxin): Why is this a def?! - def fs: FileSystem = checkpointPath.getFileSystem(hadoopConf) + val fs = checkpointPath.getFileSystem(hadoopConf) // Try to find the checkpoint files val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)).reverse diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index ac37e8e022..83a1092b16 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -106,7 +106,7 @@ class StreamingContext private[streaming] ( * HDFS compatible filesystems */ def this(path: String, hadoopConf: Configuration) = - this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).get, null) + this(null, CheckpointReader.read(path, new SparkConf(), hadoopConf).orNull, null) /** * Recreate a StreamingContext from a checkpoint file. 
@@ -122,15 +122,12 @@ class StreamingContext private[streaming] ( def this(path: String, sparkContext: SparkContext) = { this( sparkContext, - CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).get, + CheckpointReader.read(path, sparkContext.conf, sparkContext.hadoopConfiguration).orNull, null) } - - if (_sc == null && _cp == null) { - throw new Exception("Spark Streaming cannot be initialized with " + - "both SparkContext and checkpoint as null") - } + require(_sc != null || _cp != null, + "Spark Streaming cannot be initialized with both SparkContext and checkpoint as null") private[streaming] val isCheckpointPresent: Boolean = _cp != null diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 9a3248b3e8..fbb25d4c59 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -228,6 +228,11 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester } } + test("non-existent checkpoint dir") { + // SPARK-13211 + intercept[IllegalArgumentException](new StreamingContext("nosuchdirectory")) + } + test("basic rdd checkpoints + dstream graph checkpoint recovery") { assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 1 second") -- cgit v1.2.3 From f6456fa80ba442bfd7ce069fc23b7dbd993e6cb9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 6 Apr 2016 12:09:10 +0800 Subject: [SPARK-14296][SQL] whole stage codegen support for Dataset.map ## What changes were proposed in this pull request? This PR adds a new operator `MapElements` for `Dataset.map`, it's a 1-1 mapping and is easier to adapt to whole stage codegen framework. ## How was this patch tested? new test in `WholeStageCodegenSuite` Author: Wenchen Fan Closes #12087 from cloud-fan/map. --- .../spark/sql/catalyst/analysis/unresolved.scala | 2 +- .../spark/sql/catalyst/expressions/objects.scala | 40 ++++++---- .../spark/sql/catalyst/optimizer/Optimizer.scala | 9 +++ .../spark/sql/catalyst/plans/logical/object.scala | 28 ++++++- .../main/scala/org/apache/spark/sql/Dataset.scala | 22 +++--- .../spark/sql/execution/SparkStrategies.scala | 2 + .../spark/sql/execution/WholeStageCodegen.scala | 11 ++- .../org/apache/spark/sql/execution/objects.scala | 69 ++++++++++++++++- .../org/apache/spark/sql/DatasetBenchmark.scala | 86 ++++++++++++++++++++++ .../scala/org/apache/spark/sql/QueryTest.scala | 5 +- .../sql/execution/WholeStageCodegenSuite.scala | 14 +++- 11 files changed, 247 insertions(+), 41 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index b2f362b6b8..4ec43aba02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -345,7 +345,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty * if we want to resolve deserializer by children output. 
*/ -case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute]) +case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil) extends UnaryExpression with Unevaluable with NonSQLExpression { // The input attributes used to resolve deserializer expression must be all resolved. require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index eebd43dae9..a0490e1351 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -119,18 +119,18 @@ case class Invoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - lazy val method = targetObject.dataType match { + @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - cls - .getMethods - .find(_.getName == functionName) - .getOrElse(sys.error(s"Couldn't find $functionName on $cls")) - .getReturnType - .getName - case _ => "" + val m = cls.getMethods.find(_.getName == functionName) + if (m.isEmpty) { + sys.error(s"Couldn't find $functionName on $cls") + } else { + m + } + case _ => None } - lazy val unboxer = (dataType, method) match { + lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match { case (IntegerType, "java.lang.Object") => (s: String) => s"((java.lang.Integer)$s).intValue()" case (LongType, "java.lang.Object") => (s: String) => @@ -157,21 +157,31 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" + s"boolean ${ev.isNull} = ${ev.value} == null;" } else { + ev.isNull = obj.isNull "" } val value = unboxer(s"${obj.value}.$functionName($argString)") + val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { + s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;" + } else { + s""" + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + try { + ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value; + } catch (Exception e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + } + s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} - - boolean ${ev.isNull} = ${obj.isNull}; - $javaType ${ev.value} = - ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : ($javaType) $value; + $evaluate $objNullCheck """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 69b09bcb35..c085a377ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -136,6 +136,7 @@ object SamplePushDown extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. */ object EliminateSerialization extends Rule[LogicalPlan] { + // TODO: find a more general way to do this optimization. 
def apply(plan: LogicalPlan): LogicalPlan = plan transform { case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) if !deserializer.isInstanceOf[Attribute] && @@ -144,6 +145,14 @@ object EliminateSerialization extends Rule[LogicalPlan] { m.copy( deserializer = childWithoutSerialization.output.head, child = childWithoutSerialization) + + case m @ MapElements(_, deserializer, _, child: ObjectOperator) + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => + val childWithoutSerialization = child.withObjectOutput + m.copy( + deserializer = childWithoutSerialization.output.head, + child = childWithoutSerialization) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 58313c7b72..ec33a538a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -65,7 +65,7 @@ object MapPartitions { child: LogicalPlan): MapPartitions = { MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - UnresolvedDeserializer(encoderFor[T].deserializer, Nil), + UnresolvedDeserializer(encoderFor[T].deserializer), encoderFor[U].namedExpressions, child) } @@ -83,6 +83,30 @@ case class MapPartitions( serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectOperator +object MapElements { + def apply[T : Encoder, U : Encoder]( + func: AnyRef, + child: LogicalPlan): MapElements = { + MapElements( + func, + UnresolvedDeserializer(encoderFor[T].deserializer), + encoderFor[U].namedExpressions, + child) + } +} + +/** + * A relation produced by applying `func` to each element of the `child`. + * + * @param deserializer used to extract the input to `func` from an input row. + * @param serializer use to serialize the output of `func`. + */ +case class MapElements( + func: AnyRef, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectOperator + /** Factory for constructing new `AppendColumn` nodes. 
*/ object AppendColumns { def apply[T : Encoder, U : Encoder]( @@ -90,7 +114,7 @@ object AppendColumns { child: LogicalPlan): AppendColumns = { new AppendColumns( func.asInstanceOf[Any => Any], - UnresolvedDeserializer(encoderFor[T].deserializer, Nil), + UnresolvedDeserializer(encoderFor[T].deserializer), encoderFor[U].namedExpressions, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f472a5068e..2854d5f9da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -766,7 +766,8 @@ class Dataset[T] private[sql]( implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) - withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) => + + withTypedPlan { Project( leftData :: rightData :: Nil, joined.analyzed) @@ -1900,7 +1901,9 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { + MapElements[T, U](func, logicalPlan) + } /** * :: Experimental :: @@ -1911,8 +1914,10 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = - map(t => func.call(t))(encoder) + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + implicit val uEnc = encoder + withTypedPlan(MapElements[T, U](func, logicalPlan)) + } /** * :: Experimental :: @@ -2412,12 +2417,7 @@ class Dataset[T] private[sql]( } /** A convenient function to wrap a logical plan and produce a Dataset. */ - @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = { - new Dataset[T](sqlContext, logicalPlan, encoder) + @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + Dataset(sqlContext, logicalPlan) } - - private[sql] def withTypedPlan[R]( - other: Dataset[_], encoder: Encoder[R])( - f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = - new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e52f05a5f4..5f3128d8e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -341,6 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil + case logical.MapElements(f, in, out, child) => + execution.MapElements(f, in, out, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => execution.AppendColumns(f, in, out, planLater(child)) :: Nil case logical.MapGroups(f, key, in, out, grouping, data, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 9f539c4929..4e75a3a794 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -152,7 +152,7 @@ trait CodegenSupport extends 
SparkPlan { s""" | |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */ - |${evaluated} + |$evaluated |${parent.doConsume(ctx, inputVars, rowVar)} """.stripMargin } @@ -169,20 +169,20 @@ trait CodegenSupport extends SparkPlan { /** * Returns source code to evaluate the variables for required attributes, and clear the code - * of evaluated variables, to prevent them to be evaluated twice.. + * of evaluated variables, to prevent them to be evaluated twice. */ protected def evaluateRequiredVariables( attributes: Seq[Attribute], variables: Seq[ExprCode], required: AttributeSet): String = { - var evaluateVars = "" + val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars += ev.code.trim + "\n" + evaluateVars.append(ev.code.trim + "\n") ev.code = "" } } - evaluateVars + evaluateVars.toString() } /** @@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup def doCodeGen(): (CodegenContext, String) = { val ctx = new CodegenContext val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { return new GeneratedIterator(references); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 582dda8603..f48f3f09c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution +import scala.language.existentials + +import org.apache.spark.api.java.function.MapFunction import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.ObjectType @@ -67,6 +70,70 @@ case class MapPartitions( } } +/** + * Applies the given function to each input row and encodes the result. + * + * Note that, each serializer expression needs the result object which is returned by the given + * function, as input. This operator uses some tricks to make sure we only calculate the result + * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with + * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of + * a project while explain. 
+ */ +case class MapElements( + func: AnyRef, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val (funcClass, methodName) = func match { + case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" + case _ => classOf[Any => Any] -> "apply" + } + val funcObj = Literal.create(func, ObjectType(funcClass)) + val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType + val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer)) + + val bound = ExpressionCanonicalizer.execute( + BindReferences.bindReference(callFunc, child.output)) + ctx.currentVars = input + val evaluated = bound.gen(ctx) + + val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType) + val outputFields = serializer.map(_ transform { + case _: BoundReference => resultObj + }) + val resultVars = outputFields.map(_.gen(ctx)) + s""" + ${evaluated.code} + ${consume(ctx, resultVars)} + """ + } + + override protected def doExecute(): RDD[InternalRow] = { + val callFunc: Any => Any = func match { + case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i) + case _ => func.asInstanceOf[Any => Any] + } + child.execute().mapPartitionsInternal { iter => + val getObject = generateToObject(deserializer, child.output) + val outputObject = generateToRow(serializer) + iter.map(row => outputObject(callFunc(getObject(row)))) + } + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + /** * Applies the given function to each input row, appending the encoded result at the end of the row. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala new file mode 100644 index 0000000000..6eb952445f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkContext +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.Benchmark + +/** + * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions. 
+ */ +object DatasetBenchmark { + + case class Data(l: Long, s: String) + + def main(args: Array[String]): Unit = { + val sparkContext = new SparkContext("local[*]", "Dataset benchmark") + val sqlContext = new SQLContext(sparkContext) + + import sqlContext.implicits._ + + val numRows = 10000000 + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val numChains = 10 + + val benchmark = new Benchmark("back-to-back map", numRows) + + val func = (d: Data) => Data(d.l + 1, d.s) + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.select($"l" + 1 as "l") + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = rdd.map(func) + i += 1 + } + res.foreach(_ => Unit) + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Dataset 902 / 995 11.1 90.2 1.0X + DataFrame 132 / 167 75.5 13.2 6.8X + RDD 216 / 237 46.3 21.6 4.2X + */ + benchmark.run() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f7f3bd78e9..4e62fac919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest { val logicalPlan = df.queryExecution.analyzed // bypass some cases that we can't handle currently. 
logicalPlan.transform { - case _: MapPartitions => return - case _: MapGroups => return - case _: AppendColumns => return - case _: CoGroup => return + case _: ObjectOperator => return case _: LogicalRelation => return }.transformAllExpressions { case a: ImperativeAggregate => return diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 6d5be0b5dd..f73ca887f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions.{avg, broadcast, col, max} @@ -70,4 +71,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined) assert(df.collect() === Array(Row(1), Row(2), Row(3))) } + + test("MapElements should be included in WholeStageCodegen") { + import testImplicits._ + + val ds = sqlContext.range(10).map(_.toString) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) + assert(ds.collect() === 0.until(10).map(_.toString).toArray) + } } -- cgit v1.2.3 From adbfdb878dd1029738db3d1955d08b33de1aa8a9 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 5 Apr 2016 21:23:20 -0700 Subject: [SPARK-14128][SQL] Alter table DDL followup ## What changes were proposed in this pull request? This is just a followup to #12121, which implemented the alter table DDLs using the `SessionCatalog`. Specially, this corrects the behavior of setting the location of a datasource table. For datasource tables, we need to set the `locationUri` in addition to the `path` entry in the serde properties. Additionally, changing the location of a datasource table partition is not allowed. ## How was this patch tested? `DDLSuite` Author: Andrew Or Closes #12186 from andrewor14/alter-table-ddl-followup. 
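For a concrete picture of the change, here is a hedged Scala sketch written in the style of the `DDLSuite` test below. It assumes a suite mixing in `SharedSQLContext` in which `dbx.tab1` has already been registered as a datasource table with a matching partition; the statements themselves are the ones the test exercises.

```scala
// Sketch only: assumes dbx.tab1 already exists as a datasource table with a partition
// matching (a='1'), as set up in the DDLSuite test below.
import org.apache.spark.sql.AnalysisException

// For a datasource table, SET LOCATION now updates locationUri in addition to the
// "path" entry in the serde properties.
sqlContext.sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'")

// Changing the location of a datasource table partition is no longer allowed.
intercept[AnalysisException] {
  sqlContext.sql("ALTER TABLE dbx.tab1 PARTITION (a='1') SET LOCATION '/path/to/part/ways'")
}
```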
--- .../org/apache/spark/sql/execution/command/ddl.scala | 6 ++++-- .../spark/sql/execution/command/DDLSuite.scala | 20 +++++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0d38c41a3f..6d56a6fec8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -383,8 +383,9 @@ case class AlterTableSetLocation( val part = catalog.getPartition(tableName, spec) val newPart = if (DDLUtils.isDatasourceTable(table)) { - part.copy(storage = part.storage.copy( - serdeProperties = part.storage.serdeProperties ++ Map("path" -> location))) + throw new AnalysisException( + "alter table set location for partition is not allowed for tables defined " + + "using the datasource API") } else { part.copy(storage = part.storage.copy(locationUri = Some(location))) } @@ -394,6 +395,7 @@ case class AlterTableSetLocation( val newTable = if (DDLUtils.isDatasourceTable(table)) { table.withNewStorage( + locationUri = Some(location), serdeProperties = table.storage.serdeProperties ++ Map("path" -> location)) } else { table.withNewStorage(locationUri = Some(location)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index d8e2c94a8a..a8db4e9923 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -417,23 +417,37 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { .map { s => catalog.getPartition(tableIdent, s).storage } .getOrElse { catalog.getTable(tableIdent).storage } if (isDatasourceTable) { - assert(storageFormat.serdeProperties.get("path") === Some(expected)) + if (spec.isDefined) { + assert(storageFormat.serdeProperties.isEmpty) + assert(storageFormat.locationUri.isEmpty) + } else { + assert(storageFormat.serdeProperties.get("path") === Some(expected)) + assert(storageFormat.locationUri === Some(expected)) + } } else { assert(storageFormat.locationUri === Some(expected)) } } + // Optionally expect AnalysisException + def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { + if (expectException) intercept[AnalysisException] { body } else body + } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") verifyLocation("/path/to/your/lovely/heart") // set table partition location - sql("ALTER TABLE dbx.tab1 PARTITION (a='1') SET LOCATION '/path/to/part/ways'") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 PARTITION (a='1') SET LOCATION '/path/to/part/ways'") + } verifyLocation("/path/to/part/ways", Some(partSpec)) // set table location without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") verifyLocation("/swanky/steak/place") // set table partition location without explicitly specifying database - sql("ALTER TABLE tab1 PARTITION (a='1') SET LOCATION 'vienna'") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 PARTITION (a='1') SET LOCATION 'vienna'") + } verifyLocation("vienna", Some(partSpec)) // table to alter does not exist intercept[AnalysisException] { -- cgit v1.2.3 From 
48467f4eb02209a884adbcf052670a057a75fcbd Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 5 Apr 2016 22:32:37 -0700 Subject: [SPARK-14416][CORE] Add thread-safe comments for CoarseGrainedSchedulerBackend's fields ## What changes were proposed in this pull request? While I was reviewing #12078, I found most of CoarseGrainedSchedulerBackend's mutable fields doesn't have any comments about the thread-safe assumptions and it's hard for people to figure out which part of codes should be protected by the lock. This PR just added comments/annotations for them and also added strict access modifiers for some fields. ## How was this patch tested? Existing unit tests. Author: Shixiong Zhu Closes #12188 from zsxwing/comments. --- .../cluster/CoarseGrainedSchedulerBackend.scala | 37 ++++++++++++++-------- .../scheduler/cluster/YarnSchedulerBackend.scala | 9 ++++-- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 70470cc6d2..f71bfd489d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -43,24 +44,30 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed - var totalCoreCount = new AtomicInteger(0) + protected val totalCoreCount = new AtomicInteger(0) // Total number of executors that are currently registered - var totalRegisteredExecutors = new AtomicInteger(0) - val conf = scheduler.sc.conf + protected val totalRegisteredExecutors = new AtomicInteger(0) + protected val conf = scheduler.sc.conf private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf) // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. - var minRegisteredRatio = + private val _minRegisteredRatio = math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) // Submit tasks after maxRegisteredWaitingTime milliseconds // if minRegisteredRatio has not yet been reached - val maxRegisteredWaitingTimeMs = + private val maxRegisteredWaitingTimeMs = conf.getTimeAsMs("spark.scheduler.maxRegisteredResourcesWaitingTime", "30s") - val createTime = System.currentTimeMillis() + private val createTime = System.currentTimeMillis() + // Accessing `executorDataMap` in `DriverEndpoint.receive/receiveAndReply` doesn't need any + // protection. But accessing `executorDataMap` out of `DriverEndpoint.receive/receiveAndReply` + // must be protected by `CoarseGrainedSchedulerBackend.this`. Besides, `executorDataMap` should + // only be modified in `DriverEndpoint.receive/receiveAndReply` with protection by + // `CoarseGrainedSchedulerBackend.this`. 
private val executorDataMap = new HashMap[String, ExecutorData] // Number of executors requested from the cluster manager that have not registered yet + @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 private val listenerBus = scheduler.sc.listenerBus @@ -68,23 +75,26 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Executors we have requested the cluster manager to kill that have not died yet; maps // the executor ID to whether it was explicitly killed by the driver (and thus shouldn't // be considered an app-related failure). + @GuardedBy("CoarseGrainedSchedulerBackend.this") private val executorsPendingToRemove = new HashMap[String, Boolean] // A map to store hostname with its possible task number running on it + @GuardedBy("CoarseGrainedSchedulerBackend.this") protected var hostToLocalTaskCount: Map[String, Int] = Map.empty // The number of pending tasks which is locality required + @GuardedBy("CoarseGrainedSchedulerBackend.this") protected var localityAwareTasks = 0 - // Executors that have been lost, but for which we don't yet know the real exit reason. - protected val executorsPendingLossReason = new HashSet[String] - // The num of current max ExecutorId used to re-register appMaster - protected var currentExecutorIdCounter = 0 + @volatile protected var currentExecutorIdCounter = 0 class DriverEndpoint(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) extends ThreadSafeRpcEndpoint with Logging { + // Executors that have been lost, but for which we don't yet know the real exit reason. + protected val executorsPendingLossReason = new HashSet[String] + // If this DriverEndpoint is changed to support multiple threads, // then this may need to be changed so that we don't share the serializer // instance across threads @@ -261,7 +271,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { + private def removeExecutor(executorId: String, reason: ExecutorLossReason): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => // This must be synchronized because variables mutated @@ -313,7 +323,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } var driverEndpoint: RpcEndpointRef = null - val taskIdsOnSlave = new HashMap[String, HashSet[String]] + + protected def minRegisteredRatio: Double = _minRegisteredRatio override def start() { val properties = new ArrayBuffer[(String, String)] @@ -417,7 +428,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Return the number of executors currently registered with this backend. */ - def numExistingExecutors: Int = executorDataMap.size + private def numExistingExecutors: Int = executorDataMap.size /** * Request an additional number of executors from the cluster manager. 
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 5aeaf44732..8720ee57fe 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -39,9 +39,12 @@ private[spark] abstract class YarnSchedulerBackend( sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { - if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { - minRegisteredRatio = 0.8 - } + override val minRegisteredRatio = + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + 0.8 + } else { + super.minRegisteredRatio + } protected var totalExpectedExecutors = 0 -- cgit v1.2.3 From 68be5b9e8a5ac1fc4d243bb54c2ca95fee3f74dc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 5 Apr 2016 22:33:44 -0700 Subject: [SPARK-14396][SQL] Throw Exceptions for DDLs of Partitioned Views #### What changes were proposed in this pull request? Because the concept of partitioning is associated with physical tables, we disable all the supports of partitioned views, which are defined in the following three commands in [Hive DDL Manual](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-Create/Drop/AlterView): ``` ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]; ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec; CREATE VIEW [IF NOT EXISTS] [db_name.]view_name [(column_name [COMMENT column_comment], ...) ] [COMMENT view_comment] [TBLPROPERTIES (property_name = property_value, ...)] AS SELECT ...; ``` An exception is thrown when users issue any of these three DDL commands. #### How was this patch tested? Added test cases for parsing create view and changed the existing test cases to verify if the exceptions are thrown. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #12169 from gatorsmile/viewPartition. 
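To make the new behavior concrete, here is a hedged Scala sketch in the style of the parser tests added below. `HiveSqlParser.parsePlan` is the entry point used by `HiveDDLCommandSuite`; the ALTER VIEW forms are exercised in `DDLCommandSuite` through the core SQL parser, so routing them through the Hive parser here is only an assumption made for compactness.

```scala
// Sketch only: mirrors the ParseException assertions added in the test suites below.
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.hive.execution.HiveSqlParser

// Creating a partitioned view is rejected at parse time.
intercept[ParseException] {
  HiveSqlParser.parsePlan("CREATE VIEW view1 PARTITIONED ON (ds, hr) AS SELECT * FROM srcpart")
}

// ALTER VIEW ... ADD/DROP PARTITION is rejected for the same reason.
intercept[ParseException] {
  HiveSqlParser.parsePlan(
    "ALTER VIEW view1 ADD IF NOT EXISTS PARTITION (dt='2008-08-08', country='us')")
}
```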
--- .../spark/sql/execution/SparkSqlParser.scala | 10 +- .../sql/execution/command/DDLCommandSuite.scala | 48 +--- .../hive/execution/HiveCompatibilitySuite.scala | 8 +- .../spark/sql/hive/execution/HiveSqlParser.scala | 12 +- .../spark/sql/hive/HiveDDLCommandSuite.scala | 308 +++++++++++++++++++++ .../org/apache/spark/sql/hive/HiveQlSuite.scala | 254 ----------------- 6 files changed, 345 insertions(+), 295 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index d3086fc91e..3de8aa0276 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -20,7 +20,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser.{AbstractSqlParser, AstBuilder, ParseException} +import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} @@ -474,9 +474,13 @@ class SparkSqlAstBuilder extends AstBuilder { * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec * }}} + * + * ALTER VIEW ... DROP PARTITION ... is not supported because the concept of partitioning + * is associated with physical tables */ override def visitAddTablePartition( ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { + if (ctx.VIEW != null) throw new ParseException(s"Operation not allowed: partitioned views", ctx) // Create partition spec to location mapping. val specsAndLocs = if (ctx.partitionSpec.isEmpty) { ctx.partitionSpecLocation.asScala.map { @@ -538,9 +542,13 @@ class SparkSqlAstBuilder extends AstBuilder { * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; * ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]; * }}} + * + * ALTER VIEW ... DROP PARTITION ... 
is not supported because the concept of partitioning + * is associated with physical tables */ override def visitDropTablePartitions( ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { + if (ctx.VIEW != null) throw new ParseException(s"Operation not allowed: partitioned views", ctx) AlterTableDropPartition( visitTableIdentifier(ctx.tableIdentifier), ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 618c9a58a6..46dcadd690 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -351,22 +351,12 @@ class DDLCommandSuite extends PlanTest { |(col1=NULL, cOL2='f', col3=5, COL4=true) """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - - val expected1 = AlterTableAddPartition( - TableIdentifier("view_name", None), - Seq( - (Map("dt" -> "2008-08-08", "country" -> "us"), None), - (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), - ifNotExists = true)(sql1) - val expected2 = AlterTableAddPartition( - TableIdentifier("view_name", None), - Seq((Map("col1" -> "NULL", "col2" -> "f", "col3" -> "5", "col4" -> "true"), None)), - ifNotExists = false)(sql2) - - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + intercept[ParseException] { + parser.parsePlan(sql1) + } + intercept[ParseException] { + parser.parsePlan(sql2) + } } test("alter table: rename partition") { @@ -416,8 +406,13 @@ class DDLCommandSuite extends PlanTest { val parsed1_table = parser.parsePlan(sql1_table) val parsed2_table = parser.parsePlan(sql2_table) - val parsed1_view = parser.parsePlan(sql1_view) - val parsed2_view = parser.parsePlan(sql2_view) + + intercept[ParseException] { + parser.parsePlan(sql1_view) + } + intercept[ParseException] { + parser.parsePlan(sql2_view) + } val tableIdent = TableIdentifier("table_name", None) val expected1_table = AlterTableDropPartition( @@ -435,25 +430,8 @@ class DDLCommandSuite extends PlanTest { ifExists = false, purge = true)(sql2_table) - val expected1_view = AlterTableDropPartition( - tableIdent, - Seq( - Map("dt" -> "2008-08-08", "country" -> "us"), - Map("dt" -> "2009-09-09", "country" -> "uk")), - ifExists = true, - purge = false)(sql1_view) - val expected2_view = AlterTableDropPartition( - tableIdent, - Seq( - Map("dt" -> "2008-08-08", "country" -> "us"), - Map("dt" -> "2009-09-09", "country" -> "uk")), - ifExists = false, - purge = false)(sql2_table) - comparePlans(parsed1_table, expected1_table) comparePlans(parsed2_table, expected2_table) - comparePlans(parsed1_view, expected1_view) - comparePlans(parsed2_view, expected2_view) } test("alter table: archive partition") { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b01f556f0a..9e3cb18d45 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -372,7 +372,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_index", // Macro commands are not 
supported - "macro" + "macro", + + // Create partitioned view is not supported + "create_like_view", + "describe_formatted_view_partitioned" ) /** @@ -482,7 +486,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "cp_mj_rc", "create_insert_outputformat", "create_like_tbl_props", - "create_like_view", "create_nested_type", "create_skewed_table1", "create_struct_table", @@ -507,7 +510,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "default_partition_name", "delimiter", "desc_non_existent_tbl", - "describe_formatted_view_partitioned", "diff_part_input_formats", "disable_file_format_check", "disallow_incompatible_type_change_off", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index c6c0b2ca59..ab69d3502e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -215,11 +215,19 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { /** * Create or replace a view. This creates a [[CreateViewAsSelect]] command. + * + * For example: + * {{{ + * CREATE VIEW [IF NOT EXISTS] [db_name.]view_name + * [(column_name [COMMENT column_comment], ...) ] + * [COMMENT view_comment] + * [TBLPROPERTIES (property_name = property_value, ...)] + * AS SELECT ...; + * }}} */ override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) { - // Pass a partitioned view on to hive. if (ctx.identifierList != null) { - HiveNativeCommand(command(ctx)) + throw new ParseException(s"Operation not allowed: partitioned views", ctx) } else { if (ctx.STRING != null) { logWarning("COMMENT clause is ignored.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala new file mode 100644 index 0000000000..b4e5d4adf1 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -0,0 +1,308 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.hive.serde.serdeConstants + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.parser.ParseException +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} +import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser} + +class HiveDDLCommandSuite extends PlanTest { + val parser = HiveSqlParser + + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { + parser.parsePlan(sql).collect { + case CreateTableAsSelect(desc, _, allowExisting) => (desc, allowExisting) + case CreateViewAsSelect(desc, _, allowExisting, _, _) => (desc, allowExisting) + }.head + } + + test("Test CTAS #1") { + val s1 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |(viewTime INT, + |userid BIGINT, + |page_url STRING, + |referrer_url STRING, + |ip STRING COMMENT 'IP Address of the User', + |country STRING COMMENT 'country of origination') + |COMMENT 'This is the staging page view table' + |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') + |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\054' STORED AS RCFILE + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, exists) = extractTableDesc(s1) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.schema == + CatalogColumn("viewtime", "int") :: + CatalogColumn("userid", "bigint") :: + CatalogColumn("page_url", "string") :: + CatalogColumn("referrer_url", "string") :: + CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: + CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) + // TODO will be SQLText + assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.partitionColumns == + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.storage.serdeProperties == + Map((serdeConstants.SERIALIZATION_FORMAT, "\u002C"), (serdeConstants.FIELD_DELIM, "\u002C"))) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) + } + + test("Test CTAS #2") { + val s2 = + """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |(viewTime INT, + |userid BIGINT, + |page_url STRING, + |referrer_url STRING, + |ip STRING COMMENT 'IP Address of the User', + |country STRING COMMENT 'country of origination') + |COMMENT 'This is the staging page view table' + |PARTITIONED BY (dt STRING COMMENT 'date type', hour 
STRING COMMENT 'hour of the day') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src""".stripMargin + + val (desc, exists) = extractTableDesc(s2) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.schema == + CatalogColumn("viewtime", "int") :: + CatalogColumn("userid", "bigint") :: + CatalogColumn("page_url", "string") :: + CatalogColumn("referrer_url", "string") :: + CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: + CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) + // TODO will be SQLText + assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.partitionColumns == + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) + } + + test("Test CTAS #3") { + val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" + val (desc, exists) = extractTableDesc(s3) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.storage.locationUri == None) + assert(desc.schema == Seq.empty[CatalogColumn]) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + assert(desc.storage.serde.isEmpty) + assert(desc.properties == Map()) + } + + test("Test CTAS #4") { + val s4 = + """CREATE TABLE page_view + |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin + intercept[AnalysisException] { + extractTableDesc(s4) + } + } + + test("Test CTAS #5") { + val s5 = """CREATE TABLE ctas2 + | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + | STORED AS RCFile + | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + | AS + | SELECT key, value + | FROM src + | ORDER BY key, value""".stripMargin + val (desc, exists) = extractTableDesc(s5) + assert(exists == false) + assert(desc.identifier.database == None) + assert(desc.identifier.table == "ctas2") + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.storage.locationUri == None) + assert(desc.schema == Seq.empty[CatalogColumn]) + assert(desc.viewText == None) // TODO will be SQLText + assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == 
Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) + assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) + } + + test("Invalid interval term should throw AnalysisException") { + def assertError(sql: String, errorMessage: String): Unit = { + val e = intercept[AnalysisException] { + parser.parsePlan(sql) + } + assert(e.getMessage.contains(errorMessage)) + } + assertError("select interval '42-32' year to month", + "month 32 outside range [0, 11]") + assertError("select interval '5 49:12:15' day to second", + "hour 49 outside range [0, 23]") + assertError("select interval '.1111111111' second", + "nanosecond 1111111111 outside range") + } + + test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { + val plan = parser.parsePlan( + """ + |SELECT * + |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test + |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b + """.stripMargin) + + assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) + } + + test("transform query spec") { + val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + + val p = ScriptTransformation( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + "func", Seq.empty, plans.table("e"), null) + + comparePlans(plan1, + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + comparePlans(plan2, + p.copy(output = Seq('c.string, 'd.string))) + comparePlans(plan3, + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("use backticks in output of Script Transform") { + val plan = parser.parsePlan( + """SELECT `t`.`thing1` + |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) + |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t + """.stripMargin) + } + + test("use backticks in output of Generator") { + val plan = parser.parsePlan( + """ + |SELECT `gentab2`.`gencol2` + |FROM `default`.`src` + |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1` + |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2` + """.stripMargin) + } + + test("use escaped backticks in output of Generator") { + val plan = parser.parsePlan( + """ + |SELECT `gen``tab2`.`gen``col2` + |FROM `default`.`src` + |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1` + |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2` + """.stripMargin) + } + + test("create view -- basic") { + val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" + val (desc, exists) = extractTableDesc(v1) + assert(!exists) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "view1") + assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW) + assert(desc.storage.locationUri.isEmpty) + assert(desc.schema == Seq.empty[CatalogColumn]) + assert(desc.viewText.contains("SELECT * FROM tab1")) + assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat.isEmpty) + assert(desc.storage.outputFormat.isEmpty) + 
assert(desc.storage.serde.isEmpty) + assert(desc.properties == Map()) + } + + test("create view - full") { + val v1 = + """ + |CREATE OR REPLACE VIEW IF NOT EXISTS view1 + |(col1, col3) + |COMMENT 'I cannot spell' + |TBLPROPERTIES('prop1Key'="prop1Val") + |AS SELECT * FROM tab1 + """.stripMargin + val (desc, exists) = extractTableDesc(v1) + assert(exists) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "view1") + assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW) + assert(desc.storage.locationUri.isEmpty) + assert(desc.schema == + CatalogColumn("col1", null, nullable = true, None) :: + CatalogColumn("col3", null, nullable = true, None) :: Nil) + assert(desc.viewText.contains("SELECT * FROM tab1")) + assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) + assert(desc.storage.serdeProperties == Map()) + assert(desc.storage.inputFormat.isEmpty) + assert(desc.storage.outputFormat.isEmpty) + assert(desc.storage.serde.isEmpty) + assert(desc.properties == Map("prop1Key" -> "prop1Val")) + } + + test("create view -- partitioned view") { + val v1 = "CREATE VIEW view1 partitioned on (ds, hr) as select * from srcpart" + intercept[ParseException] { + parser.parsePlan(v1).isInstanceOf[HiveNativeCommand] + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala deleted file mode 100644 index a8a0d6b8de..0000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ /dev/null @@ -1,254 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import org.apache.hadoop.hive.serde.serdeConstants - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans -import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan -import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} -import org.apache.spark.sql.hive.execution.HiveSqlParser - -class HiveQlSuite extends PlanTest { - val parser = HiveSqlParser - - private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { - parser.parsePlan(sql).collect { - case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting) - }.head - } - - test("Test CTAS #1") { - val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |(viewTime INT, - |userid BIGINT, - |page_url STRING, - |referrer_url STRING, - |ip STRING COMMENT 'IP Address of the User', - |country STRING COMMENT 'country of origination') - |COMMENT 'This is the staging page view table' - |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') - |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\054' STORED AS RCFILE - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin - - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) - assert(desc.storage.locationUri == Some("/user/external/page_view")) - assert(desc.schema == - CatalogColumn("viewtime", "int") :: - CatalogColumn("userid", "bigint") :: - CatalogColumn("page_url", "string") :: - CatalogColumn("referrer_url", "string") :: - CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) - // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) - assert(desc.partitionColumns == - CatalogColumn("dt", "string", comment = Some("date type")) :: - CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) - assert(desc.storage.serdeProperties == - Map((serdeConstants.SERIALIZATION_FORMAT, "\u002C"), (serdeConstants.FIELD_DELIM, "\u002C"))) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) - } - - test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view - |(viewTime INT, - |userid BIGINT, - |page_url STRING, - |referrer_url STRING, - |ip STRING COMMENT 'IP Address of the User', - |country STRING COMMENT 'country of origination') - |COMMENT 'This is the staging page view table' - |PARTITIONED BY (dt STRING COMMENT 'date type', hour STRING COMMENT 'hour of the day') - |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' - | STORED AS - | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' - 
| OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin - - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) - assert(desc.storage.locationUri == Some("/user/external/page_view")) - assert(desc.schema == - CatalogColumn("viewtime", "int") :: - CatalogColumn("userid", "bigint") :: - CatalogColumn("page_url", "string") :: - CatalogColumn("referrer_url", "string") :: - CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) - // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) - assert(desc.partitionColumns == - CatalogColumn("dt", "string", comment = Some("date type")) :: - CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) - assert(desc.storage.serdeProperties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map(("p1", "v1"), ("p2", "v2"))) - } - - test("Test CTAS #3") { - val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" - val (desc, exists) = extractTableDesc(s3) - assert(exists == false) - assert(desc.identifier.database == None) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.MANAGED_TABLE) - assert(desc.storage.locationUri == None) - assert(desc.schema == Seq.empty[CatalogColumn]) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.storage.serdeProperties == Map()) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) - assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - assert(desc.storage.serde.isEmpty) - assert(desc.properties == Map()) - } - - test("Test CTAS #4") { - val s4 = - """CREATE TABLE page_view - |STORED BY 'storage.handler.class.name' AS SELECT * FROM src""".stripMargin - intercept[AnalysisException] { - extractTableDesc(s4) - } - } - - test("Test CTAS #5") { - val s5 = """CREATE TABLE ctas2 - | ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" - | WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") - | STORED AS RCFile - | TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") - | AS - | SELECT key, value - | FROM src - | ORDER BY key, value""".stripMargin - val (desc, exists) = extractTableDesc(s5) - assert(exists == false) - assert(desc.identifier.database == None) - assert(desc.identifier.table == "ctas2") - assert(desc.tableType == CatalogTableType.MANAGED_TABLE) - assert(desc.storage.locationUri == None) - assert(desc.schema == Seq.empty[CatalogColumn]) - assert(desc.viewText == None) // TODO will be SQLText - assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == Some("org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe")) - assert(desc.properties == 
Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) - } - - test("Invalid interval term should throw AnalysisException") { - def assertError(sql: String, errorMessage: String): Unit = { - val e = intercept[AnalysisException] { - parser.parsePlan(sql) - } - assert(e.getMessage.contains(errorMessage)) - } - assertError("select interval '42-32' year to month", - "month 32 outside range [0, 11]") - assertError("select interval '5 49:12:15' day to second", - "hour 49 outside range [0, 23]") - assertError("select interval '.1111111111' second", - "nanosecond 1111111111 outside range") - } - - test("use native json_tuple instead of hive's UDTF in LATERAL VIEW") { - val plan = parser.parsePlan( - """ - |SELECT * - |FROM (SELECT '{"f1": "value1", "f2": 12}' json) test - |LATERAL VIEW json_tuple(json, 'f1', 'f2') jt AS a, b - """.stripMargin) - - assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) - } - - test("transform query spec") { - val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") - .asInstanceOf[ScriptTransformation].copy(ioschema = null) - - val p = ScriptTransformation( - Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), - "func", Seq.empty, plans.table("e"), null) - - comparePlans(plan1, - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - comparePlans(plan2, - p.copy(output = Seq('c.string, 'd.string))) - comparePlans(plan3, - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - - test("use backticks in output of Script Transform") { - val plan = parser.parsePlan( - """SELECT `t`.`thing1` - |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) - |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t - """.stripMargin) - } - - test("use backticks in output of Generator") { - val plan = parser.parsePlan( - """ - |SELECT `gentab2`.`gencol2` - |FROM `default`.`src` - |LATERAL VIEW explode(array(array(1, 2, 3))) `gentab1` AS `gencol1` - |LATERAL VIEW explode(`gentab1`.`gencol1`) `gentab2` AS `gencol2` - """.stripMargin) - } - - test("use escaped backticks in output of Generator") { - val plan = parser.parsePlan( - """ - |SELECT `gen``tab2`.`gen``col2` - |FROM `default`.`src` - |LATERAL VIEW explode(array(array(1, 2, 3))) `gen``tab1` AS `gen``col1` - |LATERAL VIEW explode(`gen``tab1`.`gen``col1`) `gen``tab2` AS `gen``col2` - """.stripMargin) - } -} -- cgit v1.2.3 From 78c1076d0421cc41cbdb788f38b13c9a00e8f561 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 5 Apr 2016 22:37:51 -0700 Subject: [SPARK-14252] Executors do not try to download remote cached blocks ## What changes were proposed in this pull request? As mentioned in the ticket this was because one get path in the refactored `BlockManager` did not check for remote storage. ## How was this patch tested? Unit test, also verified manually with reproduction in the ticket. cc JoshRosen Author: Eric Liang Closes #12193 from ericl/spark-14252. 
--- .../main/scala/org/apache/spark/storage/BlockManager.scala | 8 ++++++++ .../scala/org/apache/spark/storage/BlockManagerSuite.scala | 13 +++++++++++++ 2 files changed, 21 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 9608418b43..35a6c63ad1 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -643,6 +643,14 @@ private[spark] class BlockManager( level: StorageLevel, classTag: ClassTag[T], makeIterator: () => Iterator[T]): Either[BlockResult, Iterator[T]] = { + // Attempt to read the block from local or remote storage. If it's present, then we don't need + // to go through the local-get-or-put path. + get(blockId) match { + case Some(block) => + return Left(block) + case _ => + // Need to compute the block. + } // Initially we hold no locks on this block. doPutIterator(blockId, makeIterator, level, classTag, keepReadLock = true) match { case None => diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 32c00ac687..66b28de10f 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -515,6 +515,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } } + test("SPARK-14252: getOrElseUpdate should still read from remote storage") { + store = makeBlockManager(8000, "executor1") + store2 = makeBlockManager(8000, "executor2") + val list1 = List(new Array[Byte](4000)) + store2.putIterator( + "list1", list1.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + assert(store.getOrElseUpdate( + "list1", + StorageLevel.MEMORY_ONLY, + ClassTag.Any, + () => throw new AssertionError("attempted to compute locally")).isLeft) + } + test("in-memory LRU storage") { testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY) } -- cgit v1.2.3 From 25a4c8e0c5c63ca4722b1da6182e0e0f0f48b73a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Apr 2016 15:48:28 +0200 Subject: [SPARK-14396][BUILD][HOT] Fix compilation against Scala 2.10 #### What changes were proposed in this pull request? 
This PR is to fix the compilation errors in Scala 2.10 build, as shown in the link: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-compile-maven-scala-2.10/735/console ``` [error] /home/jenkins/workspace/spark-master-compile-maven-scala-2.10/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala:266: value contains is not a member of Option[String] [error] assert(desc.viewText.contains("SELECT * FROM tab1")) [error] ^ [error] /home/jenkins/workspace/spark-master-compile-maven-scala-2.10/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala:267: value contains is not a member of Option[String] [error] assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) [error] ^ [error] /home/jenkins/workspace/spark-master-compile-maven-scala-2.10/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala:293: value contains is not a member of Option[String] [error] assert(desc.viewText.contains("SELECT * FROM tab1")) [error] ^ [error] /home/jenkins/workspace/spark-master-compile-maven-scala-2.10/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala:294: value contains is not a member of Option[String] [error] assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) [error] ^ [error] four errors found [error] Compile failed at Apr 5, 2016 10:59:09 PM [10.502s] ``` #### How was this patch tested? Not sure how to trigger Scala 2.10 compilation in the test environment. Author: gatorsmile Closes #12201 from gatorsmile/buildBreak2.10. --- .../scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index b4e5d4adf1..c5f01da4fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -263,8 +263,8 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW) assert(desc.storage.locationUri.isEmpty) assert(desc.schema == Seq.empty[CatalogColumn]) - assert(desc.viewText.contains("SELECT * FROM tab1")) - assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) + assert(desc.viewText == Option("SELECT * FROM tab1")) + assert(desc.viewOriginalText == Option("SELECT * FROM tab1")) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat.isEmpty) assert(desc.storage.outputFormat.isEmpty) @@ -290,8 +290,8 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.schema == CatalogColumn("col1", null, nullable = true, None) :: CatalogColumn("col3", null, nullable = true, None) :: Nil) - assert(desc.viewText.contains("SELECT * FROM tab1")) - assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) + assert(desc.viewText == Option("SELECT * FROM tab1")) + assert(desc.viewOriginalText == Option("SELECT * FROM tab1")) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat.isEmpty) assert(desc.storage.outputFormat.isEmpty) -- cgit v1.2.3 From 24015199f46b5934d3000960538539495e025acf Mon Sep 17 00:00:00 2001 From: Victor Chima Date: Wed, 6 Apr 2016 15:27:46 +0100 Subject: Added omitted word in error message ## What changes were proposed in this pull request? Added an omitted word in the error message displayed by the Graphx Pregel API when `maxIterations <= 0` ## How was this patch tested? 
Manual test Author: Victor Chima Closes #12205 from blazy2k9/hotfix/pregel-error-message. --- graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index d2e51d2ec4..646462b4a8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -119,7 +119,7 @@ object Pregel extends Logging { mergeMsg: (A, A) => A) : Graph[VD, ED] = { - require(maxIterations > 0, s"Maximum of iterations must be greater than 0," + + require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() -- cgit v1.2.3 From 5e64dab868be1a0d415fb6d6dd3463e7171fdd1a Mon Sep 17 00:00:00 2001 From: Prajwal Tuladhar Date: Wed, 6 Apr 2016 15:28:52 +0100 Subject: [SPARK-14430][BUILD] use https while downloading binaries from build/mvn ## What changes were proposed in this pull request? `./build/mvn` file was downloading binaries in non HTTPS mode. This PR tends to fix it. ## How was this patch tested? By running `./build/mvn clean package` locally Author: Prajwal Tuladhar Closes #12182 from infynyxx/mvn_use_https. --- build/mvn | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build/mvn b/build/mvn index 58058c04b8..41c0850ccb 100755 --- a/build/mvn +++ b/build/mvn @@ -72,7 +72,7 @@ install_mvn() { local MVN_VERSION="3.3.9" install_app \ - "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "https://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ "apache-maven-${MVN_VERSION}-bin.tar.gz" \ "apache-maven-${MVN_VERSION}/bin/mvn" @@ -84,7 +84,7 @@ install_zinc() { local zinc_path="zinc-0.3.9/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 install_app \ - "http://downloads.typesafe.com/zinc/0.3.9" \ + "https://downloads.typesafe.com/zinc/0.3.9" \ "zinc-0.3.9.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" @@ -100,7 +100,7 @@ install_scala() { local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" install_app \ - "http://downloads.typesafe.com/scala/${scala_version}" \ + "https://downloads.typesafe.com/scala/${scala_version}" \ "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" -- cgit v1.2.3 From 59236e5c5b9d24f90fcf8d09b23ae8b06355657e Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 6 Apr 2016 10:05:02 -0700 Subject: [SPARK-14288][SQL] Memory Sink for streaming This PR exposes the internal testing `MemorySink` though the data source API. This will allow users to easily test streaming applications in the Spark shell or other local tests. Usage: ```scala inputStream.write .format("memory") .queryName("memStream") .startStream() // Now you can query the result of the stream here. sqlContext.table("memStream") ``` The most complicated part of the logic is choosing the checkpoint directory. There are a few requirements we are attempting to satisfy here: - when working in the shell locally, it should just work with no extra configuration. - when working on a cluster you should be able to make it easily create the checkpoint on a distributed file system so you can test aggregation (state checkpoints are also stored in this directory and must be accessible from workers). 
- it should be clear that you can't resume since the data is just in memory. The chosen algorithm proceeds as follows: - the user gives a checkpoint directory, use it - if the conf has a checkpoint location, use `$location/$queryName` - if neither, create a local directory - always check to make sure there are no offsets written to the directory Author: Michael Armbrust Closes #12119 from marmbrus/memorySink. --- .../org/apache/spark/sql/DataFrameWriter.scala | 79 ++++++++++++++++----- .../spark/sql/execution/SparkStrategies.scala | 6 ++ .../spark/sql/execution/streaming/memory.scala | 8 +++ .../scala/org/apache/spark/sql/QueryTest.scala | 2 + .../spark/sql/streaming/MemorySinkSuite.scala | 82 ++++++++++++++++++++++ 5 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3332a997cd..54d250867f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -29,8 +29,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelation +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -275,23 +277,64 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 2.0.0 */ def startStream(): ContinuousQuery = { - val dataSource = - DataSource( - df.sqlContext, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - - val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) - val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { - new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString - }) - df.sqlContext.sessionState.continuousQueryManager.startQuery( - queryName, - checkpointLocation, - df, - dataSource.createSink(), - trigger) + if (source == "memory") { + val queryName = + extraOptions.getOrElse( + "queryName", throw new AnalysisException("queryName must be specified for memory sink")) + val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified => + new Path(userSpecified).toUri.toString + }.orElse { + val checkpointConfig: Option[String] = + df.sqlContext.conf.getConf( + SQLConf.CHECKPOINT_LOCATION, + None) + + checkpointConfig.map { location => + new Path(location, queryName).toUri.toString + } + }.getOrElse { + Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath + } + + // If offsets have already been created, we trying to resume a query. + val checkpointPath = new Path(checkpointLocation, "offsets") + val fs = checkpointPath.getFileSystem(df.sqlContext.sparkContext.hadoopConfiguration) + if (fs.exists(checkpointPath)) { + throw new AnalysisException( + s"Unable to resume query written to memory sink. 
Delete $checkpointPath to start over.") + } else { + checkpointPath.toUri.toString + } + + val sink = new MemorySink(df.schema) + val resultDf = Dataset.ofRows(df.sqlContext, new MemoryPlan(sink)) + resultDf.registerTempTable(queryName) + val continuousQuery = df.sqlContext.sessionState.continuousQueryManager.startQuery( + queryName, + checkpointLocation, + df, + sink, + trigger) + continuousQuery + } else { + val dataSource = + DataSource( + df.sqlContext, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + + val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) + val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { + new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString + }) + df.sqlContext.sessionState.continuousQueryManager.startQuery( + queryName, + checkpointLocation, + df, + dataSource.createSink(), + trigger) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5f3128d8e4..d77aba7260 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -30,6 +31,7 @@ import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescri import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.execution.streaming.MemoryPlan import org.apache.spark.sql.internal.SQLConf private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { @@ -332,6 +334,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil + case MemoryPlan(sink, output) => + val encoder = RowEncoder(sink.schema) + LocalTableScan(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil + case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b652530d7c..351ef404a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.types.StructType object MemoryStream { @@ -136,3 +138,9 @@ class MemorySink(val schema: 
StructType) extends Sink with Logging { } } +/** + * Used to query the data that has been written into a [[MemorySink]]. + */ +case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { + def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 4e62fac919..48a077d0e5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.streaming.MemoryPlan abstract class QueryTest extends PlanTest { @@ -200,6 +201,7 @@ abstract class QueryTest extends PlanTest { logicalPlan.transform { case _: ObjectOperator => return case _: LogicalRelation => return + case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala new file mode 100644 index 0000000000..5249aa28dd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.{AnalysisException, Row, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class MemorySinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("registering as a table") { + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + sqlContext.table("memStream").as[Int], + 1, 2, 3) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + sqlContext.table("memStream").as[Int], + 1, 2, 3, 4, 5, 6) + + query.stop() + } + + test("error when no name is specified") { + val error = intercept[AnalysisException] { + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .startStream() + } + + assert(error.message contains "queryName must be specified") + } + + test("error if attempting to resume specific checkpoint") { + val location = Utils.createTempDir("steaming.checkpoint").getCanonicalPath + + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .option("checkpointLocation", location) + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + intercept[AnalysisException] { + input.toDF().write + .format("memory") + .queryName("memStream") + .option("checkpointLocation", location) + .startStream() + } + } +} -- cgit v1.2.3 From 90ca1844865baf96656a9e5efdf56f415f2646be Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Apr 2016 10:46:34 -0700 Subject: [SPARK-14418][PYSPARK] fix unpersist of Broadcast in Python ## What changes were proposed in this pull request? Currently, Broaccast.unpersist() will remove the file of broadcast, which should be the behavior of destroy(). This PR added destroy() for Broadcast in Python, to match the sematics in Scala. ## How was this patch tested? Added regression tests. Author: Davies Liu Closes #12189 from davies/py_unpersist. --- python/pyspark/broadcast.py | 17 ++++++++++++++++- python/pyspark/tests.py | 15 +++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 663c9abe08..a0b819220e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -99,11 +99,26 @@ class Broadcast(object): def unpersist(self, blocking=False): """ - Delete cached copies of this broadcast on the executors. + Delete cached copies of this broadcast on the executors. If the + broadcast is used after this is called, it will need to be + re-sent to each executor. + + :param blocking: Whether to block until unpersisting has completed """ if self._jbroadcast is None: raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) + + def destroy(self): + """ + Destroy all data and metadata related to this broadcast variable. + Use this with caution; once a broadcast variable has been destroyed, + it cannot be used again. This method blocks until destroy has + completed. 
+ """ + if self._jbroadcast is None: + raise Exception("Broadcast can only be destroyed in driver") + self._jbroadcast.destroy() os.unlink(self._path) def __reduce__(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 40fccb8c00..15c87e22f9 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -694,6 +694,21 @@ class RDDTests(ReusedPySparkTestCase): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEqual(N, m) + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM -- cgit v1.2.3 From 10494feae0c2c1aca545c73ba61af6d8f743c5bb Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 6 Apr 2016 10:57:46 -0700 Subject: [SPARK-14426][SQL] Merge PerserUtils and ParseUtils ## What changes were proposed in this pull request? We have ParserUtils and ParseUtils which are both utility collections for use during the parsing process. Those names and what they are used for is very similar so I think we can merge them. Also, the original unescapeSQLString method may have a fault. When "\u0061" style character literals are passed to the method, it's not unescaped successfully. This patch fix the bug. ## How was this patch tested? Added a new test case. Author: Kousuke Saruta Closes #12199 from sarutak/merge-ParseUtils-and-ParserUtils. --- scalastyle-config.xml | 2 +- .../spark/sql/catalyst/parser/ParseUtils.java | 135 --------------------- .../spark/sql/catalyst/parser/ParserUtils.scala | 78 +++++++++++- .../catalyst/parser/ExpressionParserSuite.scala | 2 +- .../sql/catalyst/parser/ParserUtilsSuite.scala | 65 ++++++++++ 5 files changed, 144 insertions(+), 138 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 37d2ecf48e..33c2cbd293 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -116,7 +116,7 @@ This file is divided into 3 sections: - + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java deleted file mode 100644 index 01f89112a7..0000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser; - -import java.nio.charset.StandardCharsets; - -/** - * A couple of utility methods that help with parsing ASTs. - * - * The 'unescapeSQLString' method in this class was take from the SemanticAnalyzer in Hive: - * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java - */ -public final class ParseUtils { - private ParseUtils() { - super(); - } - - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; - - @SuppressWarnings("nls") - public static String unescapeSQLString(String b) { - Character enclosure = null; - - // Some of the strings can be passed in as unicode. For example, the - // delimiter can be passed in as \002 - So, we first check if the - // string is a unicode number, else go back to the old behavior - StringBuilder sb = new StringBuilder(b.length()); - for (int i = 0; i < b.length(); i++) { - - char currentChar = b.charAt(i); - if (enclosure == null) { - if (currentChar == '\'' || b.charAt(i) == '\"') { - enclosure = currentChar; - } - // ignore all other chars outside the enclosure - continue; - } - - if (enclosure.equals(currentChar)) { - enclosure = null; - continue; - } - - if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { - int code = 0; - int base = i + 2; - for (int j = 0; j < 4; j++) { - int digit = Character.digit(b.charAt(j + base), 16); - code += digit * multiplier[j]; - } - sb.append((char)code); - i += 5; - continue; - } - - if (currentChar == '\\' && (i + 4 < b.length())) { - char i1 = b.charAt(i + 1); - char i2 = b.charAt(i + 2); - char i3 = b.charAt(i + 3); - if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') - && (i3 >= '0' && i3 <= '7')) { - byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); - byte[] bValArr = new byte[1]; - bValArr[0] = bVal; - String tmp = new String(bValArr, StandardCharsets.UTF_8); - sb.append(tmp); - i += 3; - continue; - } - } - - if (currentChar == '\\' && (i + 2 < b.length())) { - char n = b.charAt(i + 1); - switch (n) { - case '0': - sb.append("\0"); - break; - case '\'': - sb.append("'"); - break; - case '"': - sb.append("\""); - break; - case 'b': - sb.append("\b"); - break; - case 'n': - sb.append("\n"); - break; - case 'r': - sb.append("\r"); - break; - case 't': - sb.append("\t"); - break; - case 'Z': - sb.append("\u001A"); - break; - case '\\': - sb.append("\\"); - break; - // The following 2 lines are exactly what MySQL does TODO: why do we do this? 
- case '%': - sb.append("\\%"); - break; - case '_': - sb.append("\\_"); - break; - default: - sb.append(n); - } - i++; - } else { - sb.append(currentChar); - } - } - return sb.toString(); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 90b76dc314..cb9fefec8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,11 +16,12 @@ */ package org.apache.spark.sql.catalyst.parser +import scala.collection.mutable.StringBuilder + import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} @@ -87,6 +88,81 @@ object ParserUtils { } } + /** Unescape baskslash-escaped string enclosed by quotes. */ + def unescapeSQLString(b: String): String = { + var enclosure: Character = null + val sb = new StringBuilder(b.length()) + + def appendEscapedChar(n: Char) { + n match { + case '0' => sb.append('\u0000') + case '\'' => sb.append('\'') + case '"' => sb.append('\"') + case 'b' => sb.append('\b') + case 'n' => sb.append('\n') + case 'r' => sb.append('\r') + case 't' => sb.append('\t') + case 'Z' => sb.append('\u001A') + case '\\' => sb.append('\\') + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%' => sb.append("\\%") + case '_' => sb.append("\\_") + case _ => sb.append(n) + } + } + + var i = 0 + val strLength = b.length + while (i < strLength) { + val currentChar = b.charAt(i) + if (enclosure == null) { + if (currentChar == '\'' || currentChar == '\"') { + enclosure = currentChar + } + } else if (enclosure == currentChar) { + enclosure = null + } else if (currentChar == '\\') { + + if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') { + // \u0000 style character literals. + + val base = i + 2 + val code = (0 until 4).foldLeft(0) { (mid, j) => + val digit = Character.digit(b.charAt(j + base), 16) + (mid << 4) + digit + } + sb.append(code.asInstanceOf[Char]) + i += 5 + } else if (i + 4 < strLength) { + // \000 style character literals. + + val i1 = b.charAt(i + 1) + val i2 = b.charAt(i + 2) + val i3 = b.charAt(i + 3) + + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) { + val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char] + sb.append(tmp) + i += 3 + } else { + appendEscapedChar(i1) + i += 1 + } + } else if (i + 2 < strLength) { + // escaped character literals. + val n = b.charAt(i + 1) + appendEscapedChar(n) + i += 1 + } + } else { + // non-escaped character literals. + sb.append(currentChar) + } + i += 1 + } + sb.toString() + } + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. 
*/ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index a80d29ce5d..6f40ec67ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -415,7 +415,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") // Unicode - assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)") } test("intervals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala new file mode 100644 index 0000000000..d090daf7b4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +class ParserUtilsSuite extends SparkFunSuite { + + import ParserUtils._ + + test("unescapeSQLString") { + // scalastyle:off nonascii + + // String not including escaped characters and enclosed by double quotes. + assert(unescapeSQLString(""""abcdefg"""") == "abcdefg") + + // String enclosed by single quotes. + assert(unescapeSQLString("""'C0FFEE'""") == "C0FFEE") + + // Strings including single escaped characters. + assert(unescapeSQLString("""'\0'""") == "\u0000") + assert(unescapeSQLString(""""\'"""") == "\'") + assert(unescapeSQLString("""'\"'""") == "\"") + assert(unescapeSQLString(""""\b"""") == "\b") + assert(unescapeSQLString("""'\n'""") == "\n") + assert(unescapeSQLString(""""\r"""") == "\r") + assert(unescapeSQLString("""'\t'""") == "\t") + assert(unescapeSQLString(""""\Z"""") == "\u001A") + assert(unescapeSQLString("""'\\'""") == "\\") + assert(unescapeSQLString(""""\%"""") == "\\%") + assert(unescapeSQLString("""'\_'""") == "\\_") + + // String including '\000' style literal characters. + assert(unescapeSQLString("""'3 + 5 = \070'""") == "3 + 5 = \u0038") + assert(unescapeSQLString(""""\000"""") == "\u0000") + + // String including invalid '\000' style literal characters. + assert(unescapeSQLString(""""\256"""") == "256") + + // String including a '\u0000' style literal characters (\u732B is a cat in Kanji). 
+ assert(unescapeSQLString(""""How cute \u732B are"""") == "How cute \u732B are") + + // String including a surrogate pair character + // (\uD867\uDE3D is Okhotsk atka mackerel in Kanji). + assert(unescapeSQLString(""""\uD867\uDE3D is a fish"""") == "\uD867\uDE3D is a fish") + + // scalastyle:on nonascii + } + + // TODO: Add test cases for other methods in ParserUtils +} -- cgit v1.2.3 From 5abd02c02b3fa3505defdc8ab0c5c5e23a16aa80 Mon Sep 17 00:00:00 2001 From: bomeng Date: Wed, 6 Apr 2016 11:05:52 -0700 Subject: [SPARK-14429][SQL] Improve LIKE pattern in "SHOW TABLES / FUNCTIONS LIKE " DDL LIKE is commonly used in SHOW TABLES / FUNCTIONS etc DDL. In the pattern, user can use `|` or `*` as wildcards. 1. Currently, we used `replaceAll()` to replace `*` with `.*`, but the replacement was scattered in several places; I have created an utility method and use it in all the places; 2. Consistency with Hive: the pattern is case insensitive in Hive and white spaces will be trimmed, but current pattern matching does not do that. For example, suppose we have tables (t1, t2, t3), `SHOW TABLES LIKE ' T* ' ` will list all the t-tables. Please use Hive to verify it. 3. Combined with `|`, the result will be sorted. For pattern like `' B*|a* '`, it will list the result in a-b order. I've made some changes to the utility method to make sure we will get the same result as Hive does. A new method was created in StringUtil and test cases were added. andrewor14 Author: bomeng Closes #12206 from bomeng/SPARK-14429. --- .../sql/catalyst/catalog/InMemoryCatalog.scala | 13 ++++-------- .../sql/catalyst/catalog/SessionCatalog.scala | 10 +++------- .../spark/sql/catalyst/util/StringUtils.scala | 23 +++++++++++++++++++++- .../spark/sql/catalyst/util/StringUtilsSuite.scala | 12 +++++++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 +++++------ 5 files changed, 46 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 2af0107fa3..5d136b663f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} - +import org.apache.spark.sql.catalyst.util.StringUtils /** * An in-memory (ephemeral) implementation of the system catalog. 
@@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog { // Database name -> description private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] - private def filterPattern(names: Seq[String], pattern: String): Seq[String] = { - val regex = pattern.replaceAll("\\*", ".*").r - names.filter { funcName => regex.pattern.matcher(funcName).matches() } - } - private def functionExists(db: String, funcName: String): Boolean = { requireDbExists(db) catalog(db).functions.contains(funcName) @@ -141,7 +136,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listDatabases(pattern: String): Seq[String] = synchronized { - filterPattern(listDatabases(), pattern) + StringUtils.filterPattern(listDatabases(), pattern) } override def setCurrentDatabase(db: String): Unit = { /* no-op */ } @@ -208,7 +203,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listTables(db: String, pattern: String): Seq[String] = synchronized { - filterPattern(listTables(db), pattern) + StringUtils.filterPattern(listTables(db), pattern) } // -------------------------------------------------------------------------- @@ -322,7 +317,7 @@ class InMemoryCatalog extends ExternalCatalog { override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { requireDbExists(db) - filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) + StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 62a3b1c105..2acf584e8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionE import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} - +import org.apache.spark.sql.catalyst.util.StringUtils /** * An internal catalog that is used by a Spark Session. This internal catalog serves as a @@ -297,9 +297,7 @@ class SessionCatalog( def listTables(db: String, pattern: String): Seq[TableIdentifier] = { val dbTables = externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) } - val regex = pattern.replaceAll("\\*", ".*").r - val _tempTables = tempTables.keys.toSeq - .filter { t => regex.pattern.matcher(t).matches() } + val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) .map { t => TableIdentifier(t) } dbTables ++ _tempTables } @@ -613,9 +611,7 @@ class SessionCatalog( def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = { val dbFunctions = externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } - val regex = pattern.replaceAll("\\*", ".*").r - val loadedFunctions = functionRegistry.listFunction() - .filter { f => regex.pattern.matcher(f).matches() } + val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) .map { f => FunctionIdentifier(f) } // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry. // So, the returned list may have two entries for the same function. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index c2eeb3c565..0f65028261 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import java.util.regex.Pattern +import java.util.regex.{Pattern, PatternSyntaxException} import org.apache.spark.unsafe.types.UTF8String @@ -52,4 +52,25 @@ object StringUtils { def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) + + /** + * This utility can be used for filtering pattern in the "Like" of "Show Tables / Functions" DDL + * @param names the names list to be filtered + * @param pattern the filter pattern, only '*' and '|' are allowed as wildcards, others will + * follows regular expression convention, case insensitive match and white spaces + * on both ends will be ignored + * @return the filtered names list in order + */ + def filterPattern(names: Seq[String], pattern: String): Seq[String] = { + val funcNames = scala.collection.mutable.SortedSet.empty[String] + pattern.trim().split("\\|").foreach { subPattern => + try { + val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r + funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() } + } catch { + case _: PatternSyntaxException => + } + } + funcNames.toSeq + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index d6f273f9e5..2ffc18a8d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite { assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") } + + test("filter pattern") { + val names = Seq("a1", "a2", "b2", "c3") + assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3")) + assert(filterPattern(names, "*a*") === Seq("a1", "a2")) + assert(filterPattern(names, " *a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a.* ") === Seq("a1", "a2")) + assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2")) + assert(filterPattern(names, " a. 
") === Seq("a1", "a2")) + assert(filterPattern(names, " d* ") === Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5a851b47ca..2ab7c1581c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -56,17 +57,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { - val regex = java.util.regex.Pattern.compile(pattern) - sqlContext.sessionState.functionRegistry.listFunction() - .filter(regex.matcher(_).matches()).map(Row(_)) + StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern) + .map(Row(_)) } - checkAnswer(sql("SHOW functions"), getFunctions(".*")) + checkAnswer(sql("SHOW functions"), getFunctions("*")) Seq("^c*", "*e$", "log*", "*date*").foreach { pattern => // For the pattern part, only '*' and '|' are allowed as wildcards. // For '*', we need to replace it to '.*'. - checkAnswer( - sql(s"SHOW FUNCTIONS '$pattern'"), - getFunctions(pattern.replaceAll("\\*", ".*"))) + checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) } } -- cgit v1.2.3 From 3c8d8821654e3d82ef927c55272348e1bcc34a79 Mon Sep 17 00:00:00 2001 From: bomeng Date: Wed, 6 Apr 2016 11:12:48 -0700 Subject: [SPARK-14383][SQL] missing "|" in the g4 file ## What changes were proposed in this pull request? A very trivial one. It missed "|" between DISTRIBUTE and UNSET. ## How was this patch tested? I do not think it is really needed. Author: bomeng Closes #12156 from bomeng/SPARK-14383. 
--- .../main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../org/apache/spark/sql/execution/command/DDLCommandSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 96c170be3d..8a45b4f2e1 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -645,7 +645,7 @@ nonReserved | NO | DATA | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL | SNAPSHOT | READ | WRITE | ONLY - | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION + | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 46dcadd690..8e63b69876 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.types._ @@ -685,4 +686,10 @@ class DDLCommandSuite extends PlanTest { parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData") } } + + test("SPARK-14383: DISTRIBUTE and UNSET as non-keywords") { + val sql = "SELECT distribute, unset FROM x" + val parsed = parser.parsePlan(sql) + assert(parsed.isInstanceOf[Project]) + } } -- cgit v1.2.3 From db0b06c6ea7412266158b1c710bdc8ca30e26430 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 6 Apr 2016 11:24:11 -0700 Subject: [SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13786 Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model. ## How was this patch tested? Test with Python doctest. Author: Xusen Yin Closes #12020 from yinxusen/SPARK-13786. 
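The new Python wrappers delegate to the JVM classes through `_to_java`/`_from_java`, so the user-facing round-trip mirrors the Scala persistence API. A rough Scala sketch of that round-trip, assuming a spark-shell session and an existing `training` DataFrame with "features" and "label" columns (the estimator, paths, and data here are illustrative, not part of this patch):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder().addGrid(lr.maxIter, Array(0, 1)).build()
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEstimatorParamMaps(grid)
  .setEvaluator(new BinaryClassificationEvaluator())

// Persist the configured (unfitted) validator and read it back.
cv.write.overwrite().save("/tmp/cv")
val restoredCv = CrossValidator.load("/tmp/cv")

// A fitted model round-trips the same way.
val cvModel = cv.fit(training)
cvModel.write.overwrite().save("/tmp/cvModel")
val restoredModel = CrossValidatorModel.load("/tmp/cvModel")
```

The Python `CrossValidator`/`TrainValidationSplit` tests added in this patch exercise the same sequence through `save(path)` and `load(path)`.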
--- .../scala/org/apache/spark/ml/param/params.scala | 11 + .../apache/spark/ml/tuning/CrossValidator.scala | 9 + .../spark/ml/tuning/TrainValidationSplit.scala | 9 + python/pyspark/ml/tests.py | 56 ++- python/pyspark/ml/tuning.py | 407 +++++++++++++++------ python/pyspark/ml/wrapper.py | 23 ++ 6 files changed, 404 insertions(+), 111 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index d7837b6730..c368aadd23 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.param import java.lang.reflect.Modifier +import java.util.{List => JList} import java.util.NoSuchElementException import scala.annotation.varargs @@ -833,6 +834,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) this } + /** Put param pairs with a [[java.util.List]] of values for Python. */ + private[ml] def put(paramPairs: JList[ParamPair[_]]): this.type = { + put(paramPairs.asScala: _*) + } + /** * Optionally returns the value associated with a param. */ @@ -932,6 +938,11 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) } } + /** Java-friendly method for Python API */ + private[ml] def toList: java.util.List[ParamPair[_]] = { + this.toSeq.asJava + } + /** * Number of param pairs in this map. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 040b0093b9..4d9d4d472e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.tuning +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -200,6 +204,11 @@ class CrossValidatorModel private[ml] ( @Since("1.5.0") val avgMetrics: Array[Double]) extends Model[CrossValidatorModel] with CrossValidatorParams with MLWritable { + /** A Python-friendly auxiliary constructor. */ + private[ml] def this(uid: String, bestModel: Model[_], avgMetrics: JList[Double]) = { + this(uid, bestModel, avgMetrics.asScala.toArray) + } + @Since("1.4.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 07330bb6b0..0f2179c2a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -17,6 +17,10 @@ package org.apache.spark.ml.tuning +import java.util.{List => JList} + +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -198,6 +202,11 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") val validationMetrics: Array[Double]) extends Model[TrainValidationSplitModel] with TrainValidationSplitParams with MLWritable { + /** A Python-friendly auxiliary constructor. 
*/ + private[ml] def this(uid: String, bestModel: Model[_], validationMetrics: JList[Double]) = { + this(uid, bestModel, validationMetrics.asScala.toArray) + } + @Since("1.5.0") override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index f6159b2c95..e3f873e3a7 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -44,7 +44,7 @@ import numpy as np from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier from pyspark.ml.clustering import KMeans -from pyspark.ml.evaluation import RegressionEvaluator +from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed @@ -53,7 +53,7 @@ from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.util import MLWritable, MLWriter from pyspark.ml.wrapper import JavaWrapper -from pyspark.mllib.linalg import DenseVector, SparseVector +from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -479,6 +479,32 @@ class CrossValidatorTests(PySparkTestCase): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + cvModel = cv.fit(dataset) + cvPath = temp_path + "/cv" + cv.save(cvPath) + loadedCV = CrossValidator.load(cvPath) + self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid) + self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid) + self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps()) + cvModelPath = temp_path + "/cvModel" + cvModel.save(cvModelPath) + loadedModel = CrossValidatorModel.load(cvModelPath) + self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) + class TrainValidationSplitTests(PySparkTestCase): @@ -530,6 +556,32 @@ class TrainValidationSplitTests(PySparkTestCase): "Best model should have zero induced error") self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1") + def test_save_load(self): + temp_path = tempfile.mkdtemp() + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame( + [(Vectors.dense([0.0]), 0.0), + (Vectors.dense([0.4]), 1.0), + (Vectors.dense([0.5]), 0.0), + (Vectors.dense([0.6]), 1.0), + (Vectors.dense([1.0]), 1.0)] * 10, + ["features", "label"]) + lr = LogisticRegression() + grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build() + evaluator = BinaryClassificationEvaluator() + tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator) + tvsModel = tvs.fit(dataset) + tvsPath = 
temp_path + "/tvs" + tvs.save(tvsPath) + loadedTvs = TrainValidationSplit.load(tvsPath) + self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid) + self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid) + self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps()) + tvsModelPath = temp_path + "/tvsModel" + tvsModel.save(tvsModelPath) + loadedModel = TrainValidationSplitModel.load(tvsModelPath) + self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid) + class PersistenceTest(PySparkTestCase): diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index a528d22e18..da00f317b3 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -18,12 +18,15 @@ import itertools import numpy as np +from pyspark import SparkContext from pyspark import since from pyspark.ml import Estimator, Model from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed -from pyspark.ml.util import keyword_only +from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable +from pyspark.ml.wrapper import JavaWrapper from pyspark.sql.functions import rand +from pyspark.mllib.common import inherit_doc, _py2java __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit', 'TrainValidationSplitModel'] @@ -91,7 +94,84 @@ class ParamGridBuilder(object): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] -class CrossValidator(Estimator, HasSeed): +class ValidatorParams(HasSeed): + """ + Common params for TrainValidationSplit and CrossValidator. + """ + + estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated") + estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") + evaluator = Param( + Params._dummy(), "evaluator", + "evaluator used to select hyper-parameters that maximize the validator metric") + + def setEstimator(self, value): + """ + Sets the value of :py:attr:`estimator`. + """ + return self._set(estimator=value) + + def getEstimator(self): + """ + Gets the value of estimator or its default value. + """ + return self.getOrDefault(self.estimator) + + def setEstimatorParamMaps(self, value): + """ + Sets the value of :py:attr:`estimatorParamMaps`. + """ + return self._set(estimatorParamMaps=value) + + def getEstimatorParamMaps(self): + """ + Gets the value of estimatorParamMaps or its default value. + """ + return self.getOrDefault(self.estimatorParamMaps) + + def setEvaluator(self, value): + """ + Sets the value of :py:attr:`evaluator`. + """ + return self._set(evaluator=value) + + def getEvaluator(self): + """ + Gets the value of evaluator or its default value. + """ + return self.getOrDefault(self.evaluator) + + @classmethod + def _from_java_impl(cls, java_stage): + """ + Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams. + """ + + # Load information from java_stage to the instance. + estimator = JavaWrapper._from_java(java_stage.getEstimator()) + evaluator = JavaWrapper._from_java(java_stage.getEvaluator()) + epms = [estimator._transfer_param_map_from_java(epm) + for epm in java_stage.getEstimatorParamMaps()] + return estimator, epms, evaluator + + def _to_java_impl(self): + """ + Return Java estimator, estimatorParamMaps, and evaluator from this Python instance. 
+ """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap + + java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps())) + for idx, epm in enumerate(self.getEstimatorParamMaps()): + java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm) + + java_estimator = self.getEstimator()._to_java() + java_evaluator = self.getEvaluator()._to_java() + return java_estimator, java_epms, java_evaluator + + +class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): """ K-fold cross validation. @@ -116,11 +196,6 @@ class CrossValidator(Estimator, HasSeed): .. versionadded:: 1.4.0 """ - estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated") - estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") - evaluator = Param( - Params._dummy(), "evaluator", - "evaluator used to select hyper-parameters that maximize the cross-validated metric") numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation", typeConverter=TypeConverters.toInt) @@ -148,51 +223,6 @@ class CrossValidator(Estimator, HasSeed): kwargs = self.setParams._input_kwargs return self._set(**kwargs) - @since("1.4.0") - def setEstimator(self, value): - """ - Sets the value of :py:attr:`estimator`. - """ - self._paramMap[self.estimator] = value - return self - - @since("1.4.0") - def getEstimator(self): - """ - Gets the value of estimator or its default value. - """ - return self.getOrDefault(self.estimator) - - @since("1.4.0") - def setEstimatorParamMaps(self, value): - """ - Sets the value of :py:attr:`estimatorParamMaps`. - """ - self._paramMap[self.estimatorParamMaps] = value - return self - - @since("1.4.0") - def getEstimatorParamMaps(self): - """ - Gets the value of estimatorParamMaps or its default value. - """ - return self.getOrDefault(self.estimatorParamMaps) - - @since("1.4.0") - def setEvaluator(self, value): - """ - Sets the value of :py:attr:`evaluator`. - """ - self._paramMap[self.evaluator] = value - return self - - @since("1.4.0") - def getEvaluator(self): - """ - Gets the value of evaluator or its default value. - """ - return self.getOrDefault(self.evaluator) - @since("1.4.0") def setNumFolds(self, value): """ @@ -236,7 +266,7 @@ class CrossValidator(Estimator, HasSeed): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return CrossValidatorModel(bestModel) + return self._copyValues(CrossValidatorModel(bestModel)) @since("1.4.0") def copy(self, extra=None): @@ -258,8 +288,58 @@ class CrossValidator(Estimator, HasSeed): newCV.setEvaluator(self.getEvaluator().copy(extra)) return newCV + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java CrossValidator, create and return a Python wrapper of it. + Used for ML persistence. + """ + + estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage) + numFolds = java_stage.getNumFolds() + seed = java_stage.getSeed() + # Create a new instance of this stage. 
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, + numFolds=numFolds, seed=seed) + py_stage._resetUid(java_stage.uid()) + return py_stage -class CrossValidatorModel(Model): + def _to_java(self): + """ + Transfer this instance to a Java CrossValidator. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() + + _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) + _java_obj.setEstimatorParamMaps(epms) + _java_obj.setEvaluator(evaluator) + _java_obj.setEstimator(estimator) + _java_obj.setSeed(self.getSeed()) + _java_obj.setNumFolds(self.getNumFolds()) + + return _java_obj + + +class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): """ Model from k-fold cross validation. @@ -289,8 +369,60 @@ class CrossValidatorModel(Model): extra = dict() return CrossValidatorModel(self.bestModel.copy(extra)) + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) -class TrainValidationSplit(Estimator, HasSeed): + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java CrossValidatorModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + + # Load information from java_stage to the instance. + bestModel = JavaWrapper._from_java(java_stage.bestModel()) + estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) + # Create a new instance of this stage. + py_stage = cls(bestModel=bestModel)\ + .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java CrossValidatorModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + sc = SparkContext._active_spark_context + + _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) + estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() + + _java_obj.set("evaluator", evaluator) + _java_obj.set("estimator", estimator) + _java_obj.set("estimatorParamMaps", epms) + return _java_obj + + +class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): """ Train-Validation-Split. @@ -315,11 +447,6 @@ class TrainValidationSplit(Estimator, HasSeed): .. versionadded:: 2.0.0 """ - estimator = Param(Params._dummy(), "estimator", "estimator to be tested") - estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps") - evaluator = Param( - Params._dummy(), "evaluator", - "evaluator used to select hyper-parameters that maximize the validated metric") trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\ validation data. Must be between 0 and 1.") @@ -347,51 +474,6 @@ class TrainValidationSplit(Estimator, HasSeed): kwargs = self.setParams._input_kwargs return self._set(**kwargs) - @since("2.0.0") - def setEstimator(self, value): - """ - Sets the value of :py:attr:`estimator`. 
- """ - self._paramMap[self.estimator] = value - return self - - @since("2.0.0") - def getEstimator(self): - """ - Gets the value of estimator or its default value. - """ - return self.getOrDefault(self.estimator) - - @since("2.0.0") - def setEstimatorParamMaps(self, value): - """ - Sets the value of :py:attr:`estimatorParamMaps`. - """ - self._paramMap[self.estimatorParamMaps] = value - return self - - @since("2.0.0") - def getEstimatorParamMaps(self): - """ - Gets the value of estimatorParamMaps or its default value. - """ - return self.getOrDefault(self.estimatorParamMaps) - - @since("2.0.0") - def setEvaluator(self, value): - """ - Sets the value of :py:attr:`evaluator`. - """ - self._paramMap[self.evaluator] = value - return self - - @since("2.0.0") - def getEvaluator(self): - """ - Gets the value of evaluator or its default value. - """ - return self.getOrDefault(self.evaluator) - @since("2.0.0") def setTrainRatio(self, value): """ @@ -429,7 +511,7 @@ class TrainValidationSplit(Estimator, HasSeed): else: bestIndex = np.argmin(metrics) bestModel = est.fit(dataset, epm[bestIndex]) - return TrainValidationSplitModel(bestModel) + return self._copyValues(TrainValidationSplitModel(bestModel)) @since("2.0.0") def copy(self, extra=None): @@ -451,8 +533,59 @@ class TrainValidationSplit(Estimator, HasSeed): newTVS.setEvaluator(self.getEvaluator().copy(extra)) return newTVS + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java TrainValidationSplit, create and return a Python wrapper of it. + Used for ML persistence. + """ + + estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage) + trainRatio = java_stage.getTrainRatio() + seed = java_stage.getSeed() + # Create a new instance of this stage. + py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator, + trainRatio=trainRatio, seed=seed) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java TrainValidationSplit. Used for ML persistence. -class TrainValidationSplitModel(Model): + :return: Java object equivalent to this instance. + """ + + estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() + + _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", + self.uid) + _java_obj.setEstimatorParamMaps(epms) + _java_obj.setEvaluator(evaluator) + _java_obj.setEstimator(estimator) + _java_obj.setTrainRatio(self.getTrainRatio()) + _java_obj.setSeed(self.getSeed()) + + return _java_obj + + +class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): """ Model from train validation split. 
""" @@ -480,19 +613,75 @@ class TrainValidationSplitModel(Model): extra = dict() return TrainValidationSplitModel(self.bestModel.copy(extra)) + @since("2.0.0") + def write(self): + """Returns an MLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + @since("2.0.0") + def save(self, path): + """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" + self.write().save(path) + + @classmethod + @since("2.0.0") + def read(cls): + """Returns an MLReader instance for this class.""" + return JavaMLReader(cls) + + @classmethod + def _from_java(cls, java_stage): + """ + Given a Java TrainValidationSplitModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + + # Load information from java_stage to the instance. + bestModel = JavaWrapper._from_java(java_stage.bestModel()) + estimator, epms, evaluator = \ + super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) + # Create a new instance of this stage. + py_stage = cls(bestModel=bestModel)\ + .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + sc = SparkContext._active_spark_context + + _java_obj = JavaWrapper._new_java_obj( + "org.apache.spark.ml.tuning.TrainValidationSplitModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) + estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl() + + _java_obj.set("evaluator", evaluator) + _java_obj.set("estimator", estimator) + _java_obj.set("estimatorParamMaps", epms) + return _java_obj + + if __name__ == "__main__": import doctest + from pyspark.context import SparkContext from pyspark.sql import SQLContext globs = globals().copy() + # The small batch size here ensures that we see multiple batches, # even in these small test examples: sc = SparkContext("local[2]", "ml.tuning tests") sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod( - globs=globs, optionflags=doctest.ELLIPSIS) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) sc.stop() if failure_count: exit(-1) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 35b0eba926..ca93bf7d7d 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -76,6 +76,17 @@ class JavaWrapper(Params): pair = self._make_java_param_pair(param, paramMap[param]) self._java_obj.set(pair) + def _transfer_param_map_to_java(self, pyParamMap): + """ + Transforms a Python ParamMap into a Java ParamMap. + """ + paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap") + for param in self.params: + if param in pyParamMap: + pair = self._make_java_param_pair(param, pyParamMap[param]) + paramMap.put([pair]) + return paramMap + def _transfer_params_from_java(self): """ Transforms the embedded params from the companion Java object. @@ -88,6 +99,18 @@ class JavaWrapper(Params): value = _java2py(sc, self._java_obj.getOrDefault(java_param)) self._paramMap[param] = value + def _transfer_param_map_from_java(self, javaParamMap): + """ + Transforms a Java ParamMap into a Python ParamMap. 
+ """ + sc = SparkContext._active_spark_context + paramMap = dict() + for pair in javaParamMap.toList(): + param = pair.param() + if self.hasParam(str(param.name())): + paramMap[self.getParam(param.name())] = _java2py(sc, pair.value()) + return paramMap + @staticmethod def _empty_java_param_map(): """ -- cgit v1.2.3 From 8cffcb60deb82d04a5c6e144ec9927f6f7addc8b Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 6 Apr 2016 11:36:26 -0700 Subject: [SPARK-14322][MLLIB] Use treeAggregate instead of reduce in OnlineLDAOptimizer ## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-14322 OnlineLDAOptimizer uses RDD.reduce in two places where it could use treeAggregate. This can cause scalability issues. This should be an easy fix. This is also a bug since it modifies the first argument to reduce, so we should use aggregate or treeAggregate. See this line: https://github.com/apache/spark/blob/f12f11e578169b47e3f8b18b299948c0670ba585/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala#L452 and a few lines below it. ## How was this patch tested? unit tests Author: Yuhao Yang Closes #12106 from hhbyyh/ldaTreeReduce. --- .../main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 7491ab0d51..2b404a8651 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -451,10 +451,11 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } Iterator((stat, gammaPart)) } - val statsSum: BDM[Double] = stats.map(_._1).reduce(_ += _) + val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))( + _ += _, _ += _) expElogbetaBc.unpersist() val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( - stats.map(_._2).reduce(_ ++ _).map(_.toDenseMatrix): _*) + stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) val batchResult = statsSum :* expElogbeta.t // Note that this is an optimization to avoid batch.count -- cgit v1.2.3 From af73d9737874f7adaec3cd19ac889ab3badb8e2a Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 6 Apr 2016 11:45:16 -0700 Subject: [SPARK-13538][ML] Add GaussianMixture to ML JIRA: https://issues.apache.org/jira/browse/SPARK-13538 ## What changes were proposed in this pull request? Add GaussianMixture and GaussianMixtureModel to ML package ## How was this patch tested? unit tests and manual tests were done. Local Scalastyle checks passed. Author: Zheng RuiFeng Author: Ruifeng Zheng Author: Joseph K. Bradley Closes #11419 from zhengruifeng/mlgmm. 
--- .../spark/ml/clustering/GaussianMixture.scala | 311 +++++++++++++++++++++ .../spark/ml/clustering/GaussianMixtureSuite.scala | 133 +++++++++ 2 files changed, 444 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala new file mode 100644 index 0000000000..120bf3cf9d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.hadoop.fs.Path + +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.{IntParam, ParamMap, Params} +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util._ +import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel} +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.stat.distribution.MultivariateGaussian +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.types.{IntegerType, StructType} + + +/** + * Common params for GaussianMixture and GaussianMixtureModel + */ +private[clustering] trait GaussianMixtureParams extends Params with HasMaxIter with HasFeaturesCol + with HasSeed with HasPredictionCol with HasProbabilityCol with HasTol { + + /** + * Set the number of clusters to create (k). Must be > 1. Default: 2. + * @group param + */ + @Since("2.0.0") + final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1) + + /** @group getParam */ + @Since("2.0.0") + def getK: Int = $(k) + + /** + * Validates and transforms the input schema. + * @param schema input schema + * @return output schema + */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) + SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) + SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT) + } +} + +/** + * :: Experimental :: + * Model fitted by GaussianMixture. + * @param parentModel a model trained by spark.mllib.clustering.GaussianMixture. 
+ */ +@Since("2.0.0") +@Experimental +class GaussianMixtureModel private[ml] ( + @Since("2.0.0") override val uid: String, + private val parentModel: MLlibGMModel) + extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable { + + @Since("2.0.0") + override def copy(extra: ParamMap): GaussianMixtureModel = { + val copied = new GaussianMixtureModel(uid, parentModel) + copyValues(copied, extra).setParent(this.parent) + } + + @Since("2.0.0") + override def transform(dataset: DataFrame): DataFrame = { + val predUDF = udf((vector: Vector) => predict(vector)) + val probUDF = udf((vector: Vector) => predictProbability(vector)) + dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) + .withColumn($(probabilityCol), probUDF(col($(featuresCol)))) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + private[clustering] def predict(features: Vector): Int = parentModel.predict(features) + + private[clustering] def predictProbability(features: Vector): Vector = { + Vectors.dense(parentModel.predictSoft(features)) + } + + @Since("2.0.0") + def weights: Array[Double] = parentModel.weights + + @Since("2.0.0") + def gaussians: Array[MultivariateGaussian] = parentModel.gaussians + + @Since("2.0.0") + override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) + + private var trainingSummary: Option[GaussianMixtureSummary] = None + + private[clustering] def setSummary(summary: GaussianMixtureSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Return true if there exists summary of model. + */ + @Since("2.0.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. 
+ */ + @Since("2.0.0") + def summary: GaussianMixtureSummary = trainingSummary.getOrElse { + throw new RuntimeException( + s"No training summary available for the ${this.getClass.getSimpleName}") + } +} + +@Since("2.0.0") +object GaussianMixtureModel extends MLReadable[GaussianMixtureModel] { + + @Since("2.0.0") + override def read: MLReader[GaussianMixtureModel] = new GaussianMixtureModelReader + + @Since("2.0.0") + override def load(path: String): GaussianMixtureModel = super.load(path) + + /** [[MLWriter]] instance for [[GaussianMixtureModel]] */ + private[GaussianMixtureModel] class GaussianMixtureModelWriter( + instance: GaussianMixtureModel) extends MLWriter { + + private case class Data(weights: Array[Double], mus: Array[Vector], sigmas: Array[Matrix]) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: weights and gaussians + val weights = instance.weights + val gaussians = instance.gaussians + val mus = gaussians.map(_.mu) + val sigmas = gaussians.map(_.sigma) + val data = Data(weights, mus, sigmas) + val dataPath = new Path(path, "data").toString + sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class GaussianMixtureModelReader extends MLReader[GaussianMixtureModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GaussianMixtureModel].getName + + override def load(path: String): GaussianMixtureModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + val dataPath = new Path(path, "data").toString + val row = sqlContext.read.parquet(dataPath).select("weights", "mus", "sigmas").head() + val weights = row.getSeq[Double](0).toArray + val mus = row.getSeq[Vector](1).toArray + val sigmas = row.getSeq[Matrix](2).toArray + require(mus.length == sigmas.length, "Length of Mu and Sigma array must match") + require(mus.length == weights.length, "Length of weight and Gaussian array must match") + + val gaussians = (mus zip sigmas).map { + case (mu, sigma) => + new MultivariateGaussian(mu, sigma) + } + val model = new GaussianMixtureModel(metadata.uid, new MLlibGMModel(weights, gaussians)) + + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +/** + * :: Experimental :: + * GaussianMixture clustering. 
+ */ +@Since("2.0.0") +@Experimental +class GaussianMixture @Since("2.0.0") ( + @Since("2.0.0") override val uid: String) + extends Estimator[GaussianMixtureModel] with GaussianMixtureParams with DefaultParamsWritable { + + setDefault( + k -> 2, + maxIter -> 100, + tol -> 0.01) + + @Since("2.0.0") + override def copy(extra: ParamMap): GaussianMixture = defaultCopy(extra) + + @Since("2.0.0") + def this() = this(Identifiable.randomUID("GaussianMixture")) + + /** @group setParam */ + @Since("2.0.0") + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setProbabilityCol(value: String): this.type = set(probabilityCol, value) + + /** @group setParam */ + @Since("2.0.0") + def setK(value: Int): this.type = set(k, value) + + /** @group setParam */ + @Since("2.0.0") + def setMaxIter(value: Int): this.type = set(maxIter, value) + + /** @group setParam */ + @Since("2.0.0") + def setTol(value: Double): this.type = set(tol, value) + + /** @group setParam */ + @Since("2.0.0") + def setSeed(value: Long): this.type = set(seed, value) + + @Since("2.0.0") + override def fit(dataset: DataFrame): GaussianMixtureModel = { + val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } + + val algo = new MLlibGM() + .setK($(k)) + .setMaxIterations($(maxIter)) + .setSeed($(seed)) + .setConvergenceTol($(tol)) + val parentModel = algo.run(rdd) + val model = copyValues(new GaussianMixtureModel(uid, parentModel).setParent(this)) + val summary = new GaussianMixtureSummary(model.transform(dataset), + $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) + model.setSummary(summary) + } + + @Since("2.0.0") + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } +} + +@Since("2.0.0") +object GaussianMixture extends DefaultParamsReadable[GaussianMixture] { + + @Since("2.0.0") + override def load(path: String): GaussianMixture = super.load(path) +} + +/** + * :: Experimental :: + * Summary of GaussianMixture. + * + * @param predictions [[DataFrame]] produced by [[GaussianMixtureModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param probabilityCol Name for column of predicted probability of each cluster in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.0.0") +@Experimental +class GaussianMixtureSummary private[clustering] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val probabilityCol: String, + @Since("2.0.0") val featuresCol: String, + @Since("2.0.0") val k: Int) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.0.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Probability of each cluster. + */ + @Since("2.0.0") + @transient lazy val probability: DataFrame = predictions.select(probabilityCol) + + /** + * Size of (number of data points in) each cluster. 
+ */ + @Since("2.0.0") + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala new file mode 100644 index 0000000000..8edd44e5f1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.clustering + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.DataFrame + + +class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { + + final val k = 5 + @transient var dataset: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k) + } + + test("default parameters") { + val gm = new GaussianMixture() + + assert(gm.getK === 2) + assert(gm.getFeaturesCol === "features") + assert(gm.getPredictionCol === "prediction") + assert(gm.getMaxIter === 100) + assert(gm.getTol === 0.01) + } + + test("set parameters") { + val gm = new GaussianMixture() + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setProbabilityCol("test_probability") + .setMaxIter(33) + .setSeed(123) + .setTol(1e-3) + + assert(gm.getK === 9) + assert(gm.getFeaturesCol === "test_feature") + assert(gm.getPredictionCol === "test_prediction") + assert(gm.getProbabilityCol === "test_probability") + assert(gm.getMaxIter === 33) + assert(gm.getSeed === 123) + assert(gm.getTol === 1e-3) + } + + test("parameters validation") { + intercept[IllegalArgumentException] { + new GaussianMixture().setK(1) + } + } + + test("fit, transform, and summary") { + val predictionColName = "gm_prediction" + val probabilityColName = "gm_probability" + val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName) + .setProbabilityCol(probabilityColName).setSeed(1) + val model = gm.fit(dataset) + assert(model.hasParent) + assert(model.weights.length === k) + assert(model.gaussians.length === k) + + val transformed = model.transform(dataset) + val expectedColumns = Array("features", predictionColName, probabilityColName) + expectedColumns.foreach { column => + assert(transformed.columns.contains(column)) + } + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + 
val summary: GaussianMixtureSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.probabilityCol === probabilityColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, probabilityColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + assert(summary.probability.columns === Array(probabilityColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) + } + + test("read/write") { + def checkModelData(model: GaussianMixtureModel, model2: GaussianMixtureModel): Unit = { + assert(model.weights === model2.weights) + assert(model.gaussians.map(_.mu) === model2.gaussians.map(_.mu)) + assert(model.gaussians.map(_.sigma) === model2.gaussians.map(_.sigma)) + } + val gm = new GaussianMixture() + testEstimatorAndModelReadWrite(gm, dataset, + GaussianMixtureSuite.allParamSettings, checkModelData) + } +} + +object GaussianMixtureSuite { + /** + * Mapping from all Params to valid settings which differ from the defaults. + * This is useful for tests which need to exercise all Params, such as save/load. + * This excludes input columns to simplify some tests. + */ + val allParamSettings: Map[String, Any] = Map( + "predictionCol" -> "myPrediction", + "probabilityCol" -> "myProbability", + "k" -> 3, + "maxIter" -> 2, + "tol" -> 0.01 + ) +} -- cgit v1.2.3 From bb1fa5b2182f384cb711fc2be45b0f1a8c466ed6 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Wed, 6 Apr 2016 11:59:42 -0700 Subject: [SPARK-14320][SQL] Make ColumnarBatch.Row mutable ## What changes were proposed in this pull request? In order to leverage a data structure like `AggregateHashMap` (https://github.com/apache/spark/pull/12055) to speed up aggregates with keys, we need to make `ColumnarBatch.Row` mutable. ## How was this patch tested? Unit test in `ColumnarBatchSuite`. Also, tested via `BenchmarkWholeStageCodegen`. Author: Sameer Agarwal Closes #12103 from sameeragarwal/mutable-row. --- .../sql/execution/vectorized/AggregateHashMap.java | 11 ++- .../sql/execution/vectorized/ColumnVector.java | 12 +++ .../sql/execution/vectorized/ColumnarBatch.java | 94 +++++++++++++++++++++- .../sql/execution/BenchmarkWholeStageCodegen.scala | 5 +- .../execution/vectorized/ColumnarBatchSuite.scala | 21 +++++ 5 files changed, 135 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java index abe8db589d..69ce54390f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/AggregateHashMap.java @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.vectorized; import java.util.Arrays; +import com.google.common.annotations.VisibleForTesting; + import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.types.StructType; @@ -38,9 +40,9 @@ import static org.apache.spark.sql.types.DataTypes.LongType; * for certain distribution of keys) and requires us to fall back on the latter for correctness. 
*/ public class AggregateHashMap { - public ColumnarBatch batch; - public int[] buckets; + private ColumnarBatch batch; + private int[] buckets; private int numBuckets; private int numRows = 0; private int maxSteps = 3; @@ -69,16 +71,17 @@ public class AggregateHashMap { this(schema, DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_MAX_STEPS); } - public int findOrInsert(long key) { + public ColumnarBatch.Row findOrInsert(long key) { int idx = find(key); if (idx != -1 && buckets[idx] == -1) { batch.column(0).putLong(numRows, key); batch.column(1).putLong(numRows, 0); buckets[idx] = numRows++; } - return idx; + return batch.getRow(buckets[idx]); } + @VisibleForTesting public int find(long key) { long h = hash(key); int step = 0; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 74fa6323cc..d5daaf99df 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -566,6 +566,18 @@ public abstract class ColumnVector { } } + + public final void putDecimal(int rowId, Decimal value, int precision) { + if (precision <= Decimal.MAX_INT_DIGITS()) { + putInt(rowId, value.toInt()); + } else if (precision <= Decimal.MAX_LONG_DIGITS()) { + putLong(rowId, value.toLong()); + } else { + BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue(); + putByteArray(rowId, bigInteger.toByteArray()); + } + } + /** * Returns the UTF8String for rowId. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index d1cc4e6d03..8cece73faa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.vectorized; +import java.math.BigDecimal; import java.util.*; import org.apache.commons.lang.NotImplementedException; @@ -23,6 +24,7 @@ import org.apache.commons.lang.NotImplementedException; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericMutableRow; +import org.apache.spark.sql.catalyst.expressions.MutableRow; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -91,7 +93,7 @@ public final class ColumnarBatch { * Adapter class to interop with existing components that expect internal row. A lot of * performance is lost with this translation. 
*/ - public static final class Row extends InternalRow { + public static final class Row extends MutableRow { protected int rowId; private final ColumnarBatch parent; private final int fixedLenRowSize; @@ -232,6 +234,96 @@ public final class ColumnarBatch { public Object get(int ordinal, DataType dataType) { throw new NotImplementedException(); } + + @Override + public void update(int ordinal, Object value) { + if (value == null) { + setNullAt(ordinal); + } else { + DataType dt = columns[ordinal].dataType(); + if (dt instanceof BooleanType) { + setBoolean(ordinal, (boolean) value); + } else if (dt instanceof IntegerType) { + setInt(ordinal, (int) value); + } else if (dt instanceof ShortType) { + setShort(ordinal, (short) value); + } else if (dt instanceof LongType) { + setLong(ordinal, (long) value); + } else if (dt instanceof FloatType) { + setFloat(ordinal, (float) value); + } else if (dt instanceof DoubleType) { + setDouble(ordinal, (double) value); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType) dt; + setDecimal(ordinal, Decimal.apply((BigDecimal) value, t.precision(), t.scale()), + t.precision()); + } else { + throw new NotImplementedException("Datatype not supported " + dt); + } + } + } + + @Override + public void setNullAt(int ordinal) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNull(rowId); + } + + @Override + public void setBoolean(int ordinal, boolean value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putBoolean(rowId, value); + } + + @Override + public void setByte(int ordinal, byte value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putByte(rowId, value); + } + + @Override + public void setShort(int ordinal, short value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putShort(rowId, value); + } + + @Override + public void setInt(int ordinal, int value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putInt(rowId, value); + } + + @Override + public void setLong(int ordinal, long value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putLong(rowId, value); + } + + @Override + public void setFloat(int ordinal, float value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putFloat(rowId, value); + } + + @Override + public void setDouble(int ordinal, double value) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDouble(rowId, value); + } + + @Override + public void setDecimal(int ordinal, Decimal value, int precision) { + assert (!columns[ordinal].isConstant); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDecimal(rowId, value, precision); + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 3566ef3043..5dbf619876 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -517,9 +517,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { .add("value", LongType) val map = new AggregateHashMap(schema) while (i < numKeys) { - val idx = map.findOrInsert(i.toLong) - 
map.batch.column(1).putLong(map.buckets(idx), - map.batch.column(1).getLong(map.buckets(idx)) + 1) + val row = map.findOrInsert(i.toLong) + row.setLong(1, row.getLong(1) + 1) i += 1 } var s = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 4262097e8f..8a551cd78c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -756,4 +756,25 @@ class ColumnarBatchSuite extends SparkFunSuite { }} } } + + test("mutable ColumnarBatch rows") { + val NUM_ITERS = 10 + val types = Array( + BooleanType, FloatType, DoubleType, + IntegerType, LongType, ShortType, DecimalType.IntDecimal, new DecimalType(30, 10)) + for (i <- 0 to NUM_ITERS) { + val random = new Random(System.nanoTime()) + val schema = RandomDataGenerator.randomSchema(random, numFields = 20, types) + val oldRow = RandomDataGenerator.randomRow(random, schema) + val newRow = RandomDataGenerator.randomRow(random, schema) + + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => + val batch = ColumnVectorUtils.toBatch(schema, memMode, (oldRow :: Nil).iterator.asJava) + val columnarBatchRow = batch.getRow(0) + newRow.toSeq.zipWithIndex.foreach(i => columnarBatchRow.update(i._2, i._1)) + compareStruct(schema, columnarBatchRow, newRow, 0) + batch.close() + } + } + } } -- cgit v1.2.3 From 9c6556c5f8ab013b36312db4bf02c4c6d965a535 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 6 Apr 2016 12:07:47 -0700 Subject: [SPARK-13430][PYSPARK][ML] Python API for training summaries of linear and logistic regression ## What changes were proposed in this pull request? Adding Python API for training summaries of LogisticRegression and LinearRegression in PySpark ML. ## How was this patch tested? Added unit tests to exercise the api calls for the summary classes. Also, manually verified values are expected and match those from Scala directly. Author: Bryan Cutler Closes #11621 from BryanCutler/pyspark-ml-summary-SPARK-13430. --- .../ml/classification/LogisticRegression.scala | 8 +- .../spark/ml/regression/LinearRegression.scala | 41 +++- project/MimaExcludes.scala | 3 + python/pyspark/ml/classification.py | 218 +++++++++++++++++- python/pyspark/ml/regression.py | 245 ++++++++++++++++++++- python/pyspark/ml/tests.py | 87 +++++++- python/pyspark/ml/wrapper.py | 30 ++- 7 files changed, 602 insertions(+), 30 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index aeb94a6600..37182928cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -777,10 +777,10 @@ sealed trait LogisticRegressionSummary extends Serializable { /** Dataframe outputted by the model's `transform` method. */ def predictions: DataFrame - /** Field in "predictions" which gives the calibrated probability of each instance as a vector. */ + /** Field in "predictions" which gives the calibrated probability of each class as a vector. */ def probabilityCol: String - /** Field in "predictions" which gives the true label of each instance. 
*/ + /** Field in "predictions" which gives the true label of each instance (if available). */ def labelCol: String /** Field in "predictions" which gives the features of each instance as a vector. */ @@ -794,7 +794,7 @@ sealed trait LogisticRegressionSummary extends Serializable { * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each instance as a vector. + * each class as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. @@ -818,7 +818,7 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of - * each instance. + * each class as a vector. * @param labelCol field in "predictions" which gives the true label of each instance. * @param featuresCol field in "predictions" which gives the features of each instance as a vector. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2633c06f40..9619e72a45 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -190,9 +190,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), summaryModel, model.diagInvAtWA.toArray, - $(featuresCol), Array(0D)) return lrModel.setSummary(trainingSummary) @@ -249,9 +249,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), model, Array(0D), - $(featuresCol), Array(0D)) return copyValues(model.setSummary(trainingSummary)) } else { @@ -356,9 +356,9 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String summaryModel.transform(dataset), predictionColName, $(labelCol), + $(featuresCol), model, Array(0D), - $(featuresCol), objectiveHistory) model.setSummary(trainingSummary) } @@ -421,7 +421,7 @@ class LinearRegressionModel private[ml] ( // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, - $(labelCol), summaryModel, Array(0D)) + $(labelCol), $(featuresCol), summaryModel, Array(0D)) } /** @@ -511,7 +511,7 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { /** * :: Experimental :: * Linear regression training results. Currently, the training summary ignores the - * training coefficients except for the objective trace. + * training weights except for the objective trace. * * @param predictions predictions outputted by the model's `transform` method. * @param objectiveHistory objective function (scaled loss + regularization) at each iteration. 
@@ -522,13 +522,24 @@ class LinearRegressionTrainingSummary private[regression] ( predictions: DataFrame, predictionCol: String, labelCol: String, + featuresCol: String, model: LinearRegressionModel, diagInvAtWA: Array[Double], - val featuresCol: String, val objectiveHistory: Array[Double]) - extends LinearRegressionSummary(predictions, predictionCol, labelCol, model, diagInvAtWA) { + extends LinearRegressionSummary( + predictions, + predictionCol, + labelCol, + featuresCol, + model, + diagInvAtWA) { - /** Number of training iterations until termination */ + /** + * Number of training iterations until termination + * + * This value is only available when using the "l-bfgs" solver. + * @see [[LinearRegression.solver]] + */ @Since("1.5.0") val totalIterations = objectiveHistory.length @@ -539,6 +550,10 @@ class LinearRegressionTrainingSummary private[regression] ( * Linear regression results evaluated on a dataset. * * @param predictions predictions outputted by the model's `transform` method. + * @param predictionCol Field in "predictions" which gives the predicted value of the label at + * each instance. + * @param labelCol Field in "predictions" which gives the true label of each instance. + * @param featuresCol Field in "predictions" which gives the features of each instance as a vector. */ @Since("1.5.0") @Experimental @@ -546,6 +561,7 @@ class LinearRegressionSummary private[regression] ( @transient val predictions: DataFrame, val predictionCol: String, val labelCol: String, + val featuresCol: String, val model: LinearRegressionModel, private val diagInvAtWA: Array[Double]) extends Serializable { @@ -639,6 +655,9 @@ class LinearRegressionSummary private[regression] ( /** * Standard error of estimated coefficients and intercept. + * + * This value is only available when using the "normal" solver. + * @see [[LinearRegression.solver]] */ lazy val coefficientStandardErrors: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -660,6 +679,9 @@ class LinearRegressionSummary private[regression] ( /** * T-statistic of estimated coefficients and intercept. + * + * This value is only available when using the "normal" solver. + * @see [[LinearRegression.solver]] */ lazy val tValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { @@ -677,6 +699,9 @@ class LinearRegressionSummary private[regression] ( /** * Two-sided p-value of estimated coefficients and intercept. + * + * This value is only available when using the "normal" solver. 
+ * @see [[LinearRegression.solver]] */ lazy val pValues: Array[Double] = { if (diagInvAtWA.length == 1 && diagInvAtWA(0) == 0) { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 9f245afd50..d916c49a6a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -610,6 +610,9 @@ object MimaExcludes { // [SPARK-13674][SQL] Add wholestage codegen support to Sample ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this") + ) ++ Seq( + // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") ) case v if v.startsWith("1.6") => Seq( diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 067009559b..be7f9ea9ef 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -19,15 +19,18 @@ import warnings from pyspark import since from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.mllib.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['LogisticRegression', 'LogisticRegressionModel', + 'LogisticRegressionSummary', 'LogisticRegressionTrainingSummary', + 'BinaryLogisticRegressionSummary', 'BinaryLogisticRegressionTrainingSummary', 'DecisionTreeClassifier', 'DecisionTreeClassificationModel', 'GBTClassifier', 'GBTClassificationModel', 'RandomForestClassifier', 'RandomForestClassificationModel', @@ -233,6 +236,219 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, mse, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_blrt_summary = self._call_java("summary") + # Note: Once multiclass is added, update this to return correct summary + return BinaryLogisticRegressionTrainingSummary(java_blrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_blr_summary = self._call_java("evaluate", dataset) + return BinaryLogisticRegressionSummary(java_blr_summary) + + +class LogisticRegressionSummary(JavaCallable): + """ + Abstraction for Logistic Regression Results for a given model. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. 
+ """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def probabilityCol(self): + """ + Field in "predictions" which gives the calibrated probability + of each class as a vector. + """ + return self._call_java("probabilityCol") + + @property + @since("2.0.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("2.0.0") + def featuresCol(self): + """ + Field in "predictions" which gives the features of each instance + as a vector. + """ + return self._call_java("featuresCol") + + +@inherit_doc +class LogisticRegressionTrainingSummary(LogisticRegressionSummary): + """ + Abstraction for multinomial Logistic Regression Training results. + Currently, the training summary ignores the training weights except + for the objective trace. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. + """ + return self._call_java("objectiveHistory") + + @property + @since("2.0.0") + def totalIterations(self): + """ + Number of training iterations until termination. + """ + return self._call_java("totalIterations") + + +@inherit_doc +class BinaryLogisticRegressionSummary(LogisticRegressionSummary): + """ + .. note:: Experimental + + Binary Logistic regression results for a given model. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def roc(self): + """ + Returns the receiver operating characteristic (ROC) curve, + which is an Dataframe having two fields (FPR, TPR) with + (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + Reference: http://en.wikipedia.org/wiki/Receiver_operating_characteristic + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("roc") + + @property + @since("2.0.0") + def areaUnderROC(self): + """ + Computes the area under the receiver operating characteristic + (ROC) curve. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("areaUnderROC") + + @property + @since("2.0.0") + def pr(self): + """ + Returns the precision-recall curve, which is an Dataframe + containing two fields recall, precision with (0.0, 1.0) prepended + to it. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("pr") + + @property + @since("2.0.0") + def fMeasureByThreshold(self): + """ + Returns a dataframe with two fields (threshold, F-Measure) curve + with beta = 1.0. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("fMeasureByThreshold") + + @property + @since("2.0.0") + def precisionByThreshold(self): + """ + Returns a dataframe with two fields (threshold, precision) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the precision. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. 
+ """ + return self._call_java("precisionByThreshold") + + @property + @since("2.0.0") + def recallByThreshold(self): + """ + Returns a dataframe with two fields (threshold, recall) curve. + Every possible probability obtained in transforming the dataset + are used as thresholds used in calculating the recall. + + Note: This ignores instance weights (setting all to 1.0) from + `LogisticRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("recallByThreshold") + + +@inherit_doc +class BinaryLogisticRegressionTrainingSummary(BinaryLogisticRegressionSummary, + LogisticRegressionTrainingSummary): + """ + .. note:: Experimental + + Binary Logistic regression training results for a given model. + + .. versionadded:: 2.0.0 + """ + pass + class TreeClassifierParams(object): """ diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index de8a5e4bed..6cd1b4bf3a 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -20,8 +20,9 @@ import warnings from pyspark import since from pyspark.ml.param.shared import * from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable from pyspark.mllib.common import inherit_doc +from pyspark.sql import DataFrame __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', @@ -29,6 +30,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'GBTRegressor', 'GBTRegressionModel', 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', + 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', 'RandomForestRegressor', 'RandomForestRegressionModel'] @@ -131,7 +133,6 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model weights. """ - warnings.warn("weights is deprecated. Use coefficients instead.") return self._call_java("weights") @@ -151,6 +152,246 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ return self._call_java("intercept") + @property + @since("2.0.0") + def summary(self): + """ + Gets summary (e.g. residuals, mse, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + java_lrt_summary = self._call_java("summary") + return LinearRegressionTrainingSummary(java_lrt_summary) + + @property + @since("2.0.0") + def hasSummary(self): + """ + Indicates whether a training summary exists for this model + instance. + """ + return self._call_java("hasSummary") + + @since("2.0.0") + def evaluate(self, dataset): + """ + Evaluates the model on a test dataset. + + :param dataset: + Test dataset to evaluate model on, where dataset is an + instance of :py:class:`pyspark.sql.DataFrame` + """ + if not isinstance(dataset, DataFrame): + raise ValueError("dataset must be a DataFrame but got %s." % type(dataset)) + java_lr_summary = self._call_java("evaluate", dataset) + return LinearRegressionSummary(java_lr_summary) + + +class LinearRegressionSummary(JavaCallable): + """ + .. note:: Experimental + + Linear regression results evaluated on a dataset. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def predictions(self): + """ + Dataframe outputted by the model's `transform` method. + """ + return self._call_java("predictions") + + @property + @since("2.0.0") + def predictionCol(self): + """ + Field in "predictions" which gives the predicted value of + the label at each instance. 
+ """ + return self._call_java("predictionCol") + + @property + @since("2.0.0") + def labelCol(self): + """ + Field in "predictions" which gives the true label of each + instance. + """ + return self._call_java("labelCol") + + @property + @since("2.0.0") + def featuresCol(self): + """ + Field in "predictions" which gives the features of each instance + as a vector. + """ + return self._call_java("featuresCol") + + @property + @since("2.0.0") + def explainedVariance(self): + """ + Returns the explained variance regression score. + explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + Reference: http://en.wikipedia.org/wiki/Explained_variation + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("explainedVariance") + + @property + @since("2.0.0") + def meanAbsoluteError(self): + """ + Returns the mean absolute error, which is a risk function + corresponding to the expected value of the absolute error + loss or l1-norm loss. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("meanAbsoluteError") + + @property + @since("2.0.0") + def meanSquaredError(self): + """ + Returns the mean squared error, which is a risk function + corresponding to the expected value of the squared error + loss or quadratic loss. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("meanSquaredError") + + @property + @since("2.0.0") + def rootMeanSquaredError(self): + """ + Returns the root mean squared error, which is defined as the + square root of the mean squared error. + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("rootMeanSquaredError") + + @property + @since("2.0.0") + def r2(self): + """ + Returns R^2^, the coefficient of determination. + Reference: http://en.wikipedia.org/wiki/Coefficient_of_determination + + Note: This ignores instance weights (setting all to 1.0) from + `LinearRegression.weightCol`. This will change in later Spark + versions. + """ + return self._call_java("r2") + + @property + @since("2.0.0") + def residuals(self): + """ + Residuals (label - predicted value) + """ + return self._call_java("residuals") + + @property + @since("2.0.0") + def numInstances(self): + """ + Number of instances in DataFrame predictions + """ + return self._call_java("numInstances") + + @property + @since("2.0.0") + def devianceResiduals(self): + """ + The weighted residuals, the usual residuals rescaled by the + square root of the instance weights. + """ + return self._call_java("devianceResiduals") + + @property + @since("2.0.0") + def coefficientStandardErrors(self): + """ + Standard error of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("coefficientStandardErrors") + + @property + @since("2.0.0") + def tValues(self): + """ + T-statistic of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. 
seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("tValues") + + @property + @since("2.0.0") + def pValues(self): + """ + Two-sided p-value of estimated coefficients and intercept. + This value is only available when using the "normal" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("pValues") + + +@inherit_doc +class LinearRegressionTrainingSummary(LinearRegressionSummary): + """ + .. note:: Experimental + + Linear regression training results. Currently, the training summary ignores the + training weights except for the objective trace. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def objectiveHistory(self): + """ + Objective function (scaled loss + regularization) at each + iteration. + This value is only available when using the "l-bfgs" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("objectiveHistory") + + @property + @since("2.0.0") + def totalIterations(self): + """ + Number of training iterations until termination. + This value is only available when using the "l-bfgs" solver. + + .. seealso:: :py:attr:`LinearRegression.solver` + """ + return self._call_java("totalIterations") + @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e3f873e3a7..2dcd5eeb52 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -239,6 +239,17 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): return self._set(**kwargs) +class HasThrowableProperty(Params): + + def __init__(self): + super(HasThrowableProperty, self).__init__() + self.p = Param(self, "none", "empty param") + + @property + def test_property(self): + raise RuntimeError("Test property to raise error when invoked") + + class ParamTests(PySparkTestCase): def test_copy_new_parent(self): @@ -749,15 +760,75 @@ class PersistenceTest(PySparkTestCase): pass -class HasThrowableProperty(Params): - - def __init__(self): - super(HasThrowableProperty, self).__init__() - self.p = Param(self, "none", "empty param") +class TrainingSummaryTest(PySparkTestCase): - @property - def test_property(self): - raise RuntimeError("Test property to raise error when invoked") + def test_linear_regression_summary(self): + from pyspark.mllib.linalg import Vectors + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LinearRegression(maxIter=5, regParam=0.0, solver="normal", weightCol="weight", + fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.predictionCol, "prediction") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertAlmostEqual(s.explainedVariance, 0.25, 2) + self.assertAlmostEqual(s.meanAbsoluteError, 0.0) + self.assertAlmostEqual(s.meanSquaredError, 0.0) + self.assertAlmostEqual(s.rootMeanSquaredError, 0.0) + self.assertAlmostEqual(s.r2, 1.0, 2) + self.assertTrue(isinstance(s.residuals, DataFrame)) + self.assertEqual(s.numInstances, 2) + devResiduals = s.devianceResiduals + 
self.assertTrue(isinstance(devResiduals, list) and isinstance(devResiduals[0], float)) + coefStdErr = s.coefficientStandardErrors + self.assertTrue(isinstance(coefStdErr, list) and isinstance(coefStdErr[0], float)) + tValues = s.tValues + self.assertTrue(isinstance(tValues, list) and isinstance(tValues[0], float)) + pValues = s.pValues + self.assertTrue(isinstance(pValues, list) and isinstance(pValues[0], float)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.explainedVariance, s.explainedVariance) + + def test_logistic_regression_summary(self): + from pyspark.mllib.linalg import Vectors + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([(1.0, 2.0, Vectors.dense(1.0)), + (0.0, 2.0, Vectors.sparse(1, [], []))], + ["label", "weight", "features"]) + lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) + model = lr.fit(df) + self.assertTrue(model.hasSummary) + s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + # test evaluation (with training dataset) produces a summary with same values + # one check is enough to verify a summary is returned, Scala version runs full test + sameSummary = model.evaluate(df) + self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) if __name__ == "__main__": diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index ca93bf7d7d..a2cf2296fb 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -213,8 +213,30 @@ class JavaTransformer(Transformer, JavaWrapper): return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) +class JavaCallable(object): + """ + Wrapper for a plain object in JVM to make Java calls, can be used + as a mixin to another class that defines a _java_obj wrapper + """ + def __init__(self, java_obj=None, sc=None): + super(JavaCallable, self).__init__() + self._sc = sc if sc is not None else SparkContext._active_spark_context + # if this class is a mixin and _java_obj is already defined then don't initialize + if java_obj is not None or not hasattr(self, "_java_obj"): + self._java_obj = java_obj + + def __del__(self): + if self._java_obj is not None: + self._sc._gateway.detach(self._java_obj) + + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + java_args = [_py2java(self._sc, arg) for arg in args] + return _java2py(self._sc, m(*java_args)) + + @inherit_doc -class JavaModel(Model, JavaTransformer): +class JavaModel(Model, JavaCallable, JavaTransformer): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. 
Subclasses should inherit this class before @@ -259,9 +281,3 @@ class JavaModel(Model, JavaTransformer): that._java_obj = self._java_obj.copy(self._empty_java_param_map()) that._transfer_params_to_java() return that - - def _call_java(self, name, *args): - m = getattr(self._java_obj, name) - sc = SparkContext._active_spark_context - java_args = [_py2java(sc, arg) for arg in args] - return _java2py(sc, m(*java_args)) -- cgit v1.2.3 From a4ead6d3881f071a2ae53ff1c961c6ac388cac1d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 6 Apr 2016 12:28:04 -0700 Subject: [SPARK-14382][SQL] QueryProgress should be post after committedOffsets is updated ## What changes were proposed in this pull request? Make sure QueryProgress is post after committedOffsets is updated. If QueryProgress is post before committedOffsets is updated, the listener may see a wrong sinkStatus (created from committedOffsets). See https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-maven-hadoop-2.2/644/testReport/junit/org.apache.spark.sql.util/ContinuousQueryListenerSuite/single_listener/ for an example of the failure. ## How was this patch tested? Existing unit tests. Author: Shixiong Zhu Closes #12155 from zsxwing/SPARK-14382. --- .../spark/sql/execution/streaming/StreamExecution.scala | 15 +++++---------- .../spark/sql/util/ContinuousQueryListenerSuite.scala | 3 +-- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3e4acb752a..688e051e1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -159,7 +159,7 @@ class StreamExecution( triggerExecutor.execute(() => { if (isActive) { if (dataAvailable) runBatch() - commitAndConstructNextBatch() + constructNextBatch() true } else { false @@ -207,7 +207,7 @@ class StreamExecution( case None => // We are starting this stream for the first time. logInfo(s"Starting new continuous query.") currentBatchId = 0 - commitAndConstructNextBatch() + constructNextBatch() } } @@ -227,15 +227,8 @@ class StreamExecution( /** * Queries all of the sources to see if any new data is available. When there is new data the * batchId counter is incremented and a new log entry is written with the newest offsets. - * - * Note that committing the offsets for a new batch implicitly marks the previous batch as - * finished and thus this method should only be called when all currently available data - * has been written to the sink. */ - private def commitAndConstructNextBatch(): Boolean = { - // Update committed offsets. - committedOffsets ++= availableOffsets - + private def constructNextBatch(): Boolean = { // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). // If we interrupt some thread running Shell.runCommand, we may hit this issue. // As "FileStreamSource.getOffset" will create a file using HDFS API and call "Shell.runCommand" @@ -331,6 +324,8 @@ class StreamExecution( val batchTime = (System.nanoTime() - startTime).toDouble / 1000000 logInfo(s"Completed up to $availableOffsets in ${batchTime}ms") + // Update committed offsets. 
+ committedOffsets ++= availableOffsets postEvent(new QueryProgress(this)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala index d04783ecac..3498fe83d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala @@ -146,7 +146,6 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = { - @volatile var query: StreamExecution = null try { failAfter(1 minute) { sqlContext.streams.addListener(listener) @@ -212,7 +211,7 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with case class QueryStatus( active: Boolean, - expection: Option[Exception], + exception: Option[Exception], sourceStatuses: Array[SourceStatus], sinkStatus: SinkStatus) -- cgit v1.2.3 From 5a4b11a901703464b9261dea0642d80cf8d4856c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Apr 2016 15:33:39 -0700 Subject: [SPARK-14224] [SPARK-14223] [SPARK-14310] [SQL] fix RowEncoder and parquet reader for wide table ## What changes were proposed in this pull request? 1) fix the RowEncoder for wide table (many columns) by splitting the generate code into multiple functions. 2) Separate DataSourceScan as RowDataSourceScan and BatchedDataSourceScan 3) Disable the returning columnar batch in parquet reader if there are many columns. 4) Added a internal config for maximum number of fields (nested) columns supported by whole stage codegen. Closes #12098 ## How was this patch tested? Add a tests for table with 1000 columns. Author: Davies Liu Closes #12047 from davies/many_columns. 
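To make the intent concrete, here is a rough sketch of the wide-table round trip this commit targets (illustration only, not part of the patch: `sqlContext`, the `/tmp/wide_table` path, and the explicit `setConf` call are placeholders; the column-generation pattern mirrors the new ParquetQuerySuite test further down):

    import org.apache.spark.sql.SaveMode
    import org.apache.spark.sql.functions.col

    // 1000 columns derived from a single range column, as in the new "read/write wide table" test.
    val wide = sqlContext.range(1000)
      .select(Seq.tabulate(1000)(i => (col("id") + i).as(s"c$i")): _*)
    wide.write.mode(SaveMode.Overwrite).parquet("/tmp/wide_table")

    // With more nested fields than spark.sql.codegen.maxFields (default 200), whole-stage
    // codegen is skipped and the parquet scan returns rows instead of columnar batches.
    sqlContext.read.parquet("/tmp/wide_table").count()

    // The limit is the internal config added by this commit; raising it trades larger
    // generated methods for whole-stage codegen coverage of wider schemas.
    sqlContext.setConf("spark.sql.codegen.maxFields", "1500")

The new ParquetQuerySuite test added below exercises exactly this write/read round trip.
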
--- .../spark/sql/catalyst/expressions/objects.scala | 24 +- .../parquet/VectorizedParquetRecordReader.java | 19 +- .../scala/org/apache/spark/sql/SQLContext.scala | 2 +- .../apache/spark/sql/execution/ExistingRDD.scala | 291 +++++++++++---------- .../spark/sql/execution/WholeStageCodegen.scala | 13 +- .../execution/datasources/DataSourceStrategy.scala | 6 +- .../execution/datasources/FileSourceStrategy.scala | 2 +- .../execution/datasources/SqlNewHadoopRDD.scala | 34 +-- .../datasources/parquet/ParquetRelation.scala | 77 ++++-- .../org/apache/spark/sql/internal/SQLConf.scala | 11 + .../org/apache/spark/sql/sources/interfaces.scala | 9 + .../datasources/FileSourceStrategySuite.scala | 3 +- .../datasources/parquet/ParquetQuerySuite.scala | 10 + 13 files changed, 267 insertions(+), 234 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index a0490e1351..28b6b2adf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -524,22 +524,26 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") - val schemaField = ctx.addReferenceObj("schema", schema) - s""" - boolean ${ev.isNull} = false; - final Object[] $values = new Object[${children.size}]; - """ + - children.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) - eval.code + s""" + ctx.addMutableState("Object[]", values, "") + + val childrenCodes = children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" if (${eval.isNull}) { $values[$i] = null; } else { $values[$i] = ${eval.value}; } """ - }.mkString("\n") + - s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);" + } + val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) + val schemaField = ctx.addReferenceObj("schema", schema) + s""" + boolean ${ev.isNull} = false; + $values = new Object[${children.size}]; + $childrenCode + final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); + """ } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index a0b6276ef5..51bdf0f0f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,7 +31,8 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; /** * A specialized RecordReader that reads into InternalRows or ColumnarBatches directly using the @@ -99,20 +100,6 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private static final MemoryMode DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; - 
/** - * Tries to initialize the reader for this split. Returns true if this reader supports reading - * this split and false otherwise. - */ - public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) - throws IOException, InterruptedException { - try { - initialize(inputSplit, taskAttemptContext); - return true; - } catch (UnsupportedOperationException e) { - return false; - } - } - /** * Implementation of RecordReader API. */ @@ -222,7 +209,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa return columnarBatch; } - /** + /* * Can be called before any rows are returned to enable returning columnar batches directly. */ public void enableReturningBatches() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1c9cb79ba4..9259ff4062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -120,7 +120,7 @@ class SQLContext private[sql]( */ @transient protected[sql] lazy val sessionState: SessionState = new SessionState(self) - protected[sql] def conf: SQLConf = sessionState.conf + protected[spark] def conf: SQLConf = sessionState.conf /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index ab575e90c9..392c48fb7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{AtomicType, DataType} object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -123,28 +123,30 @@ private[sql] case class PhysicalRDD( } } -/** Physical plan node for scanning data from a relation. 
*/ -private[sql] case class DataSourceScan( - output: Seq[Attribute], - rdd: RDD[InternalRow], - @transient relation: BaseRelation, - override val metadata: Map[String, String] = Map.empty) - extends LeafNode with CodegenSupport { +private[sql] trait DataSourceScan extends LeafNode { + val rdd: RDD[InternalRow] + val relation: BaseRelation override val nodeName: String = relation.toString // Ignore rdd when checking results - override def sameResult(plan: SparkPlan ): Boolean = plan match { + override def sameResult(plan: SparkPlan): Boolean = plan match { case other: DataSourceScan => relation == other.relation && metadata == other.metadata case _ => false } +} - private[sql] override lazy val metrics = if (canProcessBatches()) { - Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), - "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - } else { +/** Physical plan node for scanning data from a relation. */ +private[sql] case class RowDataSourceScan( + output: Seq[Attribute], + rdd: RDD[InternalRow], + @transient relation: BaseRelation, + override val outputPartitioning: Partitioning, + override val metadata: Map[String, String] = Map.empty) + extends DataSourceScan with CodegenSupport { + + private[sql] override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - } val outputUnsafeRows = relation match { case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] => @@ -153,38 +155,6 @@ private[sql] case class DataSourceScan( case _ => false } - override val outputPartitioning = { - val bucketSpec = relation match { - // TODO: this should be closer to bucket planning. - case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec - case _ => None - } - - def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse { - throw new AnalysisException(s"bucket column $colName not found in existing columns " + - s"(${output.map(_.name).mkString(", ")})") - } - - bucketSpec.map { spec => - val numBuckets = spec.numBuckets - val bucketColumns = spec.bucketColumnNames.map(toAttribute) - HashPartitioning(bucketColumns, numBuckets) - }.getOrElse { - UnknownPartitioning(0) - } - } - - private def canProcessBatches(): Boolean = { - relation match { - case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] && - SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) && - SQLContext.getActive().get.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED) => - true - case _ => - false - } - } - protected override def doExecute(): RDD[InternalRow] = { val unsafeRow = if (outputUnsafeRows) { rdd @@ -211,6 +181,57 @@ private[sql] case class DataSourceScan( rdd :: Nil } + override protected def doProduce(ctx: CodegenContext): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + // PhysicalRDD always just has one input + val input = ctx.freshName("input") + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val exprRows = output.zipWithIndex.map{ case (a, i) => + new BoundReference(i, a.dataType, a.nullable) + } + val row = ctx.freshName("row") + ctx.INPUT_ROW = row + ctx.currentVars = null + val columnsRowInput = exprRows.map(_.gen(ctx)) + val inputRow = if (outputUnsafeRows) row else null + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, columnsRowInput, inputRow).trim} + | if 
(shouldStop()) return; + |} + """.stripMargin + } +} + +/** Physical plan node for scanning data from a batched relation. */ +private[sql] case class BatchedDataSourceScan( + output: Seq[Attribute], + rdd: RDD[InternalRow], + @transient relation: BaseRelation, + override val outputPartitioning: Partitioning, + override val metadata: Map[String, String] = Map.empty) + extends DataSourceScan with CodegenSupport { + + private[sql] override lazy val metrics = + Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), + "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } + + override def simpleString: String = { + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" + val metadataStr = metadataEntries.mkString(" ", ", ", "") + s"BatchedScan $nodeName${output.mkString("[", ",", "]")}$metadataStr" + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + rdd :: Nil + } + private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String, dataType: DataType, nullable: Boolean): ExprCode = { val javaType = ctx.javaType(dataType) @@ -232,113 +253,65 @@ private[sql] case class DataSourceScan( // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen // never requires UnsafeRow as input. override protected def doProduce(ctx: CodegenContext): String = { - val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" - val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector" val input = ctx.freshName("input") - val idx = ctx.freshName("batchIdx") - val rowidx = ctx.freshName("rowIdx") - val batch = ctx.freshName("batch") // PhysicalRDD always just has one input ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + // metrics + val numOutputRows = metricTerm(ctx, "numOutputRows") + val scanTimeMetric = metricTerm(ctx, "scanTime") + val scanTimeTotalNs = ctx.freshName("scanTime") + ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + + val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" + val batch = ctx.freshName("batch") ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + + val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector" + val idx = ctx.freshName("batchIdx") ctx.addMutableState("int", idx, s"$idx = 0;") val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) val columnAssigns = colVars.zipWithIndex.map { case (name, i) => ctx.addMutableState(columnVectorClz, name, s"$name = null;") - s"$name = ${batch}.column($i);" } - - val row = ctx.freshName("row") - val numOutputRows = metricTerm(ctx, "numOutputRows") + s"$name = $batch.column($i);" + } - // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this - // by looking at the first value of the RDD and then calling the function which will process - // the remaining. It is faster to return batches. - // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know - // here which path to use. Fix this. 
+ val nextBatch = ctx.freshName("nextBatch") + ctx.addNewFunction(nextBatch, + s""" + |private void $nextBatch() throws java.io.IOException { + | long getBatchStart = System.nanoTime(); + | if ($input.hasNext()) { + | $batch = ($columnarBatchClz)$input.next(); + | $numOutputRows.add($batch.numRows()); + | $idx = 0; + | ${columnAssigns.mkString("", "\n", "\n")} + | } + | $scanTimeTotalNs += System.nanoTime() - getBatchStart; + |}""".stripMargin) - val exprRows = - output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable)) - ctx.INPUT_ROW = row ctx.currentVars = null - val columnsRowInput = exprRows.map(_.gen(ctx)) - val inputRow = if (outputUnsafeRows) row else null - val scanRows = ctx.freshName("processRows") - ctx.addNewFunction(scanRows, - s""" - | private void $scanRows(InternalRow $row) throws java.io.IOException { - | boolean firstRow = true; - | while (!shouldStop() && (firstRow || $input.hasNext())) { - | if (firstRow) { - | firstRow = false; - | } else { - | $row = (InternalRow) $input.next(); - | } - | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput, inputRow).trim} - | } - | }""".stripMargin) - - // Timers for how long we spent inside the scan. We can only maintain this when using batches, - // otherwise the overhead is too high. - if (canProcessBatches()) { - val scanTimeMetric = metricTerm(ctx, "scanTime") - val getBatchStart = ctx.freshName("scanStart") - val scanTimeTotalNs = ctx.freshName("scanTime") - ctx.currentVars = null - val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => - genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } - val scanBatches = ctx.freshName("processBatches") - ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") - - ctx.addNewFunction(scanBatches, - s""" - | private void $scanBatches() throws java.io.IOException { - | while (true) { - | int numRows = $batch.numRows(); - | if ($idx == 0) { - | ${columnAssigns.mkString("", "\n", "\n")} - | $numOutputRows.add(numRows); - | } - | - | while (!shouldStop() && $idx < numRows) { - | int $rowidx = $idx++; - | ${consume(ctx, columnsBatchInput).trim} - | } - | if (shouldStop()) return; - | - | long $getBatchStart = System.nanoTime(); - | if (!$input.hasNext()) { - | $batch = null; - | $scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); - | break; - | } - | $batch = ($columnarBatchClz)$input.next(); - | $scanTimeTotalNs += System.nanoTime() - $getBatchStart; - | $idx = 0; - | } - | }""".stripMargin) - - val value = ctx.freshName("value") - s""" - | if ($batch != null) { - | $scanBatches(); - | } else if ($input.hasNext()) { - | Object $value = $input.next(); - | if ($value instanceof $columnarBatchClz) { - | $batch = ($columnarBatchClz)$value; - | $scanBatches(); - | } else { - | $scanRows((InternalRow) $value); - | } - | } - """.stripMargin - } else { - s""" - |if ($input.hasNext()) { - | $scanRows((InternalRow) $input.next()); - |} - """.stripMargin + val rowidx = ctx.freshName("rowIdx") + val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => + genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } + s""" + |if ($batch == null) { + | $nextBatch(); + |} + |while ($batch != null) { + | int numRows = $batch.numRows(); + | while ($idx < numRows) { + | int $rowidx = $idx++; + | ${consume(ctx, columnsBatchInput).trim} + | if (shouldStop()) return; + | } + | $batch = null; + | $nextBatch(); + |} + |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); + |$scanTimeTotalNs = 0; + 
""".stripMargin } } @@ -346,4 +319,38 @@ private[sql] object DataSourceScan { // Metadata keys val INPUT_PATHS = "InputPaths" val PUSHED_FILTERS = "PushedFilters" + + def create( + output: Seq[Attribute], + rdd: RDD[InternalRow], + relation: BaseRelation, + metadata: Map[String, String] = Map.empty): DataSourceScan = { + val outputPartitioning = { + val bucketSpec = relation match { + // TODO: this should be closer to bucket planning. + case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec + case _ => None + } + + def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse { + throw new AnalysisException(s"bucket column $colName not found in existing columns " + + s"(${output.map(_.name).mkString(", ")})") + } + + bucketSpec.map { spec => + val numBuckets = spec.numBuckets + val bucketColumns = spec.bucketColumnNames.map(toAttribute) + HashPartitioning(bucketColumns, numBuckets) + }.getOrElse { + UnknownPartitioning(0) + } + } + + relation match { + case r: HadoopFsRelation if r.fileFormat.supportBatch(r.sqlContext, relation.schema) => + BatchedDataSourceScan(output, rdd, relation, outputPartitioning, metadata) + case _ => + RowDataSourceScan(output, rdd, relation, outputPartitioning, metadata) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 4e75a3a794..98129d6c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** * An interface for those physical operators that support codegen. 
@@ -433,12 +434,20 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case _ => true } + private def numOfNestedFields(dataType: DataType): Int = dataType match { + case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum + case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) + case a: ArrayType => numOfNestedFields(a.elementType) + case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) + case _ => 1 + } + private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns - val haveManyColumns = plan.output.length > 200 - !willFallback && !haveManyColumns + val haveTooManyFields = numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields + !willFallback && !haveTooManyFields case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 52c8f3ef0b..8c183317f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -238,7 +238,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } case l @ LogicalRelation(baseRelation: TableScan, _, _) => - execution.DataSourceScan( + execution.DataSourceScan.create( l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), @@ -610,7 +610,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Don't request columns that are only referenced by pushed filters. 
.filterNot(handledSet.contains) - val scan = execution.DataSourceScan( + val scan = execution.DataSourceScan.create( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, metadata) @@ -620,7 +620,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val requestedColumns = (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq - val scan = execution.DataSourceScan( + val scan = execution.DataSourceScan.create( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, metadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 618d5a522b..aa1f76450c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -181,7 +181,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { } val scan = - DataSourceScan( + DataSourceScan.create( readDataColumns ++ partitionColumns, new FileScanRDD( files.sqlContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 159fdc99dd..6ddb218a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -97,13 +97,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient protected val jobId = new JobID(jobTrackerId, id) - // If true, enable using the custom RecordReader for parquet. This only works for - // a subset of the types (no complex types). - protected val enableVectorizedParquetReader: Boolean = - sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean - protected val enableWholestageCodegen: Boolean = - sqlContext.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key).toBoolean - override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -165,32 +158,9 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private[this] var reader: RecordReader[Void, V] = null - - /** - * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this - * fails (for example, unsupported schema), try with the normal reader. - * TODO: plumb this through a different way? 
- */ - if (enableVectorizedParquetReader && - format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { - val parquetReader: VectorizedParquetRecordReader = new VectorizedParquetRecordReader() - if (!parquetReader.tryInitialize( - split.serializableHadoopSplit.value, hadoopAttemptContext)) { - parquetReader.close() - } else { - reader = parquetReader.asInstanceOf[RecordReader[Void, V]] - parquetReader.resultBatch() - // Whole stage codegen (PhysicalRDD) is able to deal with batches directly - if (enableWholestageCodegen) parquetReader.enableReturningBatches() - } - } - - if (reader == null) { - reader = format.createRecordReader( + private[this] var reader: RecordReader[Void, V] = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - } + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 5b58fa1fc5..a2fd8da782 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -24,7 +24,6 @@ import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.{Failure, Try} -import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -53,7 +52,7 @@ import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{AtomicType, DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.BitSet @@ -276,6 +275,16 @@ private[sql] class DefaultSource file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + /** + * Returns whether the reader will the rows as batch or not. + */ + override def supportBatch(sqlContext: SQLContext, schema: StructType): Boolean = { + val conf = SQLContext.getActive().get.conf + conf.useFileScan && conf.parquetVectorizedReaderEnabled && + conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + override def buildReader( sqlContext: SQLContext, dataSchema: StructType, @@ -306,6 +315,10 @@ private[sql] class DefaultSource SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP)) + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + val returningBatch = + supportBatch(sqlContext, StructType(partitionSchema.fields ++ dataSchema.fields)) + // Try to push down filters when filter push-down is enabled. val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) { filters @@ -324,10 +337,8 @@ private[sql] class DefaultSource // TODO: if you move this into the closure it reverts to the default values. 
// If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). - val enableVectorizedParquetReader: Boolean = - sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean - val enableWholestageCodegen: Boolean = - sqlContext.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key).toBoolean + val enableVectorizedParquetReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled && + dataSchema.forall(_.dataType.isInstanceOf[AtomicType]) (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -347,32 +358,27 @@ private[sql] class DefaultSource val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedConf.value.value, attemptId) - val parquetReader = try { - if (!enableVectorizedParquetReader) sys.error("Vectorized reader turned off.") + val parquetReader = if (enableVectorizedParquetReader) { val vectorizedReader = new VectorizedParquetRecordReader() vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) - // Whole stage codegen (PhysicalRDD) is able to deal with batches directly - // TODO: fix column appending - if (enableWholestageCodegen) { - logDebug(s"Enabling batch returning") + if (returningBatch) { vectorizedReader.enableReturningBatches() } vectorizedReader - } catch { - case NonFatal(e) => - logDebug(s"Falling back to parquet-mr: $e", e) - val reader = pushed match { - case Some(filter) => - new ParquetRecordReader[InternalRow]( - new CatalystReadSupport, - FilterCompat.get(filter, null)) - case _ => - new ParquetRecordReader[InternalRow](new CatalystReadSupport) - } - reader.initialize(split, hadoopAttemptContext) - reader + } else { + logDebug(s"Falling back to parquet-mr") + val reader = pushed match { + case Some(filter) => + new ParquetRecordReader[InternalRow]( + new CatalystReadSupport, + FilterCompat.get(filter, null)) + case _ => + new ParquetRecordReader[InternalRow](new CatalystReadSupport) + } + reader.initialize(split, hadoopAttemptContext) + reader } val iter = new RecordReaderIterator(parquetReader) @@ -432,13 +438,21 @@ private[sql] class DefaultSource val setInputPaths = ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _ + val allPrimitiveTypes = dataSchema.forall(_.dataType.isInstanceOf[AtomicType]) + val inputFormatCls = if (sqlContext.conf.parquetVectorizedReaderEnabled + && allPrimitiveTypes) { + classOf[VectorizedParquetInputFormat] + } else { + classOf[ParquetInputFormat[InternalRow]] + } + Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( sqlContext = sqlContext, broadcastedConf = broadcastedConf, initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = classOf[ParquetInputFormat[InternalRow]], + inputFormatClass = inputFormatCls, valueClass = classOf[InternalRow]) { val cacheMetadata = useMetadataCache @@ -481,6 +495,17 @@ private[sql] class DefaultSource } } +/** + * The ParquetInputFormat that create VectorizedParquetRecordReader. 
+ */ +final class VectorizedParquetInputFormat extends ParquetInputFormat[InternalRow] { + override def createRecordReader( + inputSplit: InputSplit, + taskAttemptContext: TaskAttemptContext): ParquetRecordReader[InternalRow] = { + new VectorizedParquetRecordReader().asInstanceOf[ParquetRecordReader[InternalRow]] + } +} + // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[sql] class ParquetOutputWriter( path: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 927af89949..dc6ba1bcfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -396,6 +396,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val WHOLESTAGE_MAX_NUM_FIELDS = SQLConfigBuilder("spark.sql.codegen.maxFields") + .internal() + .doc("The maximum number of fields (including nested fields) that will be supported before" + + " deactivating whole-stage codegen.") + .intConf + .createWithDefault(200) + val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf @@ -480,6 +487,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) + def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) @@ -504,6 +513,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 14e14710f6..6acb41dd1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -468,6 +468,15 @@ trait FileFormat { broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] + /** + * Returns whether this format support returning columnar batch or not. + * + * TODO: we should just have different traits for the different formats. + */ + def supportBatch(sqlContext: SQLContext, dataSchema: StructType): Boolean = { + false + } + /** * Returns a function that can be used to read a single file in as an Iterator of InternalRow. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 4446a6881c..41f536fc37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -279,7 +279,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi /** Plans the query and calls the provided validation function with the planned partitioning. */ def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = { val fileScan = df.queryExecution.executedPlan.collect { - case DataSourceScan(_, scan: FileScanRDD, _, _) => scan + case scan: DataSourceScan if scan.rdd.isInstanceOf[FileScanRDD] => + scan.rdd.asInstanceOf[FileScanRDD] }.headOption.getOrElse { fail(s"No FileScan in query\n${df.queryExecution}") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2f806ebba6..7d206e7bc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -579,6 +579,16 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext assert(CatalystReadSupport.expandUDT(schema) === expected) } + + test("read/write wide table") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + df.write.mode(SaveMode.Overwrite).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df) + } + } } object TestingUDT { -- cgit v1.2.3 From de4792605ad94d3d7548a2139372bb6cac331079 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 6 Apr 2016 15:45:03 -0700 Subject: [SPARK-14391][LAUNCHER] Increase test timeouts. Most of the time tests should still pass really quickly; it's just when machines are overloaded that the tests may take a little time, but that's still preferable over just failing the test. Author: Marcelo Vanzin Closes #12210 from vanzin/SPARK-14391. --- .../test/java/org/apache/spark/launcher/LauncherServerSuite.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index a9039b3ec9..bfe1fcc87f 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -83,13 +83,13 @@ public class LauncherServerSuite extends BaseSuite { client = new TestClient(s); client.send(new Hello(handle.getSecret(), "1.4.0")); - assertTrue(semaphore.tryAcquire(1, TimeUnit.SECONDS)); + assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); // Make sure the server matched the client to the handle. 
assertNotNull(handle.getConnection()); client.send(new SetAppId("app-id")); - assertTrue(semaphore.tryAcquire(1, TimeUnit.SECONDS)); + assertTrue(semaphore.tryAcquire(30, TimeUnit.SECONDS)); assertEquals("app-id", handle.getAppId()); client.send(new SetState(SparkAppHandle.State.RUNNING)); @@ -97,7 +97,7 @@ public class LauncherServerSuite extends BaseSuite { assertEquals(SparkAppHandle.State.RUNNING, handle.getState()); handle.stop(); - Message stopMsg = client.inbound.poll(10, TimeUnit.SECONDS); + Message stopMsg = client.inbound.poll(30, TimeUnit.SECONDS); assertTrue(stopMsg instanceof Stop); } finally { kill(handle); -- cgit v1.2.3 From 9af5423ec28258becf27dbe89833b4f7d324d26a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 6 Apr 2016 15:46:20 -0700 Subject: [SPARK-12133][STREAMING] Streaming dynamic allocation ## What changes were proposed in this pull request? Added a new Executor Allocation Manager for the Streaming scheduler for doing Streaming Dynamic Allocation. ## How was this patch tested Unit tests, and cluster tests. Author: Tathagata Das Closes #12154 from tdas/streaming-dynamic-allocation. --- .../apache/spark/ExecutorAllocationClient.scala | 4 + .../main/scala/org/apache/spark/SparkContext.scala | 10 + .../cluster/CoarseGrainedSchedulerBackend.scala | 4 + .../apache/spark/streaming/StreamingContext.scala | 7 +- .../scheduler/ExecutorAllocationManager.scala | 233 ++++++++++++ .../spark/streaming/scheduler/JobScheduler.scala | 14 + .../streaming/scheduler/ReceiverTracker.scala | 19 + .../scheduler/ExecutorAllocationManagerSuite.scala | 395 +++++++++++++++++++++ 8 files changed, 683 insertions(+), 3 deletions(-) create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala index 842bfdbadc..8baddf45bf 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala @@ -23,6 +23,10 @@ package org.apache.spark */ private[spark] trait ExecutorAllocationClient { + + /** Get the list of currently active executors */ + private[spark] def getExecutorIds(): Seq[String] + /** * Update the cluster manager on our scheduling needs. Three bits of information are included * to help it make decisions. diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4b3264cbf5..c40fada64b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1360,6 +1360,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli listenerBus.addListener(listener) } + private[spark] override def getExecutorIds(): Seq[String] = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.getExecutorIds() + case _ => + logWarning("Requesting executors is only supported in coarse-grained mode") + Nil + } + } + /** * Update the cluster manager on our scheduling needs. Three bits of information are included * to help it make decisions. 
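To make the scaling policy concrete, here is a rough sketch of the decision implemented by the new `ExecutorAllocationManager` (its full class comment appears later in this patch). The batch duration, observed processing time, and the two ratio thresholds are illustrative placeholders; the real thresholds come from configuration:

    // Ratio-based scaling decision, mirroring manageAllocation() in ExecutorAllocationManager.
    val batchDurationMs = 1000L      // streaming batch interval
    val avgBatchProcTimeMs = 1500L   // average processing time reported via StreamingListener
    val scalingUpRatio = 0.9         // placeholder threshold
    val scalingDownRatio = 0.3       // placeholder threshold

    val ratio = avgBatchProcTimeMs.toDouble / batchDurationMs   // 1.5: batches fall behind

    if (ratio >= scalingUpRatio) {
      // Request max(round(ratio), 1) executors on top of the currently active ones,
      // bounded by the configured min/max; here round(1.5) = 2 additional executors.
      val numNewExecutors = math.max(math.round(ratio).toInt, 1)
      println(s"scale up: request $numNewExecutors more executors")
    } else if (ratio <= scalingDownRatio) {
      // Ask the ExecutorAllocationClient to kill one executor that is not running a receiver.
      println("scale down: remove one non-receiver executor")
    }

The request and kill steps go through the ExecutorAllocationClient methods whose plumbing is added in the surrounding core diffs (getExecutorIds, requestTotalExecutors, killExecutor).
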
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f71bfd489d..e5abf0e150 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -430,6 +430,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp */ private def numExistingExecutors: Int = executorDataMap.size + override def getExecutorIds(): Seq[String] = { + executorDataMap.keySet.toSeq + } + /** * Request an additional number of executors from the cluster manager. * @return whether the request is acknowledged. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 83a1092b16..cc187f5cb4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -43,7 +43,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver -import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} +import org.apache.spark.streaming.scheduler.{ExecutorAllocationManager, JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} import org.apache.spark.util.{CallSite, ShutdownHookManager, ThreadUtils, Utils} @@ -527,11 +527,12 @@ class StreamingContext private[streaming] ( } } - if (Utils.isDynamicAllocationEnabled(sc.conf)) { + if (Utils.isDynamicAllocationEnabled(sc.conf) || + ExecutorAllocationManager.isDynamicAllocationEnabled(conf)) { logWarning("Dynamic Allocation is enabled for this application. " + "Enabling Dynamic allocation for Spark Streaming applications can cause data loss if " + "Write Ahead Log is not enabled for non-replayable sources like Flume. " + - "See the programming guide for details on how to enable the Write Ahead Log") + "See the programming guide for details on how to enable the Write Ahead Log.") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala new file mode 100644 index 0000000000..f7b6584893 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManager.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.spark.streaming.scheduler + +import scala.util.Random + +import org.apache.spark.{ExecutorAllocationClient, SparkConf} +import org.apache.spark.internal.Logging +import org.apache.spark.streaming.util.RecurringTimer +import org.apache.spark.util.{Clock, Utils} + +/** + * Class that manages executor allocated to a StreamingContext, and dynamically request or kill + * executors based on the statistics of the streaming computation. This is different from the core + * dynamic allocation policy; the core policy relies on executors being idle for a while, but the + * micro-batch model of streaming prevents any particular executors from being idle for a long + * time. Instead, the measure of "idle-ness" needs to be based on the time taken to process + * each batch. + * + * At a high level, the policy implemented by this class is as follows: + * - Use StreamingListener interface get batch processing times of completed batches + * - Periodically take the average batch completion times and compare with the batch interval + * - If (avg. proc. time / batch interval) >= scaling up ratio, then request more executors. + * The number of executors requested is based on the ratio = (avg. proc. time / batch interval). + * - If (avg. proc. time / batch interval) <= scaling down ratio, then try to kill a executor that + * is not running a receiver. + * + * This features should ideally be used in conjunction with backpressure, as backpressure ensures + * system stability, while executors are being readjusted. + */ +private[streaming] class ExecutorAllocationManager( + client: ExecutorAllocationClient, + receiverTracker: ReceiverTracker, + conf: SparkConf, + batchDurationMs: Long, + clock: Clock) extends StreamingListener with Logging { + + import ExecutorAllocationManager._ + + private val scalingIntervalSecs = conf.getTimeAsSeconds( + SCALING_INTERVAL_KEY, + s"${SCALING_INTERVAL_DEFAULT_SECS}s") + private val scalingUpRatio = conf.getDouble(SCALING_UP_RATIO_KEY, SCALING_UP_RATIO_DEFAULT) + private val scalingDownRatio = conf.getDouble(SCALING_DOWN_RATIO_KEY, SCALING_DOWN_RATIO_DEFAULT) + private val minNumExecutors = conf.getInt( + MIN_EXECUTORS_KEY, + math.max(1, receiverTracker.numReceivers)) + private val maxNumExecutors = conf.getInt(MAX_EXECUTORS_KEY, Integer.MAX_VALUE) + private val timer = new RecurringTimer(clock, scalingIntervalSecs * 1000, + _ => manageAllocation(), "streaming-executor-allocation-manager") + + @volatile private var batchProcTimeSum = 0L + @volatile private var batchProcTimeCount = 0 + + validateSettings() + + def start(): Unit = { + timer.start() + logInfo(s"ExecutorAllocationManager started with " + + s"ratios = [$scalingUpRatio, $scalingDownRatio] and interval = $scalingIntervalSecs sec") + } + + def stop(): Unit = { + timer.stop(interruptTimer = true) + logInfo("ExecutorAllocationManager stopped") + } + + /** + * Manage executor allocation by requesting or killing executors based on the collected + * batch statistics. 
+ */ + private def manageAllocation(): Unit = synchronized { + logInfo(s"Managing executor allocation with ratios = [$scalingUpRatio, $scalingDownRatio]") + if (batchProcTimeCount > 0) { + val averageBatchProcTime = batchProcTimeSum / batchProcTimeCount + val ratio = averageBatchProcTime.toDouble / batchDurationMs + logInfo(s"Average: $averageBatchProcTime, ratio = $ratio" ) + if (ratio >= scalingUpRatio) { + logDebug("Requesting executors") + val numNewExecutors = math.max(math.round(ratio).toInt, 1) + requestExecutors(numNewExecutors) + } else if (ratio <= scalingDownRatio) { + logDebug("Killing executors") + killExecutor() + } + } + batchProcTimeSum = 0 + batchProcTimeCount = 0 + } + + /** Request the specified number of executors over the currently active one */ + private def requestExecutors(numNewExecutors: Int): Unit = { + require(numNewExecutors >= 1) + val allExecIds = client.getExecutorIds() + logDebug(s"Executors (${allExecIds.size}) = ${allExecIds}") + val targetTotalExecutors = + math.max(math.min(maxNumExecutors, allExecIds.size + numNewExecutors), minNumExecutors) + client.requestTotalExecutors(targetTotalExecutors, 0, Map.empty) + logInfo(s"Requested total $targetTotalExecutors executors") + } + + /** Kill an executor that is not running any receiver, if possible */ + private def killExecutor(): Unit = { + val allExecIds = client.getExecutorIds() + logDebug(s"Executors (${allExecIds.size}) = ${allExecIds}") + + if (allExecIds.nonEmpty && allExecIds.size > minNumExecutors) { + val execIdsWithReceivers = receiverTracker.allocatedExecutors.values.flatten.toSeq + logInfo(s"Executors with receivers (${execIdsWithReceivers.size}): ${execIdsWithReceivers}") + + val removableExecIds = allExecIds.diff(execIdsWithReceivers) + logDebug(s"Removable executors (${removableExecIds.size}): ${removableExecIds}") + if (removableExecIds.nonEmpty) { + val execIdToRemove = removableExecIds(Random.nextInt(removableExecIds.size)) + client.killExecutor(execIdToRemove) + logInfo(s"Requested to kill executor $execIdToRemove") + } else { + logInfo(s"No non-receiver executors to kill") + } + } else { + logInfo("No available executor to kill") + } + } + + private def addBatchProcTime(timeMs: Long): Unit = synchronized { + batchProcTimeSum += timeMs + batchProcTimeCount += 1 + logDebug( + s"Added batch processing time $timeMs, sum = $batchProcTimeSum, count = $batchProcTimeCount") + } + + private def validateSettings(): Unit = { + require( + scalingIntervalSecs > 0, + s"Config $SCALING_INTERVAL_KEY must be more than 0") + + require( + scalingUpRatio > 0, + s"Config $SCALING_UP_RATIO_KEY must be more than 0") + + require( + scalingDownRatio > 0, + s"Config $SCALING_DOWN_RATIO_KEY must be more than 0") + + require( + minNumExecutors > 0, + s"Config $MIN_EXECUTORS_KEY must be more than 0") + + require( + maxNumExecutors > 0, + s"$MAX_EXECUTORS_KEY must be more than 0") + + require( + scalingUpRatio > scalingDownRatio, + s"Config $SCALING_UP_RATIO_KEY must be more than config $SCALING_DOWN_RATIO_KEY") + + if (conf.contains(MIN_EXECUTORS_KEY) && conf.contains(MAX_EXECUTORS_KEY)) { + require( + maxNumExecutors >= minNumExecutors, + s"Config $MAX_EXECUTORS_KEY must be more than config $MIN_EXECUTORS_KEY") + } + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = { + logDebug("onBatchCompleted called: " + batchCompleted) + if (!batchCompleted.batchInfo.outputOperationInfos.values.exists(_.failureReason.nonEmpty)) { + 
batchCompleted.batchInfo.processingDelay.foreach(addBatchProcTime) + } + } +} + +private[streaming] object ExecutorAllocationManager extends Logging { + val ENABLED_KEY = "spark.streaming.dynamicAllocation.enabled" + + val SCALING_INTERVAL_KEY = "spark.streaming.dynamicAllocation.scalingInterval" + val SCALING_INTERVAL_DEFAULT_SECS = 60 + + val SCALING_UP_RATIO_KEY = "spark.streaming.dynamicAllocation.scalingUpRatio" + val SCALING_UP_RATIO_DEFAULT = 0.9 + + val SCALING_DOWN_RATIO_KEY = "spark.streaming.dynamicAllocation.scalingDownRatio" + val SCALING_DOWN_RATIO_DEFAULT = 0.3 + + val MIN_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.minExecutors" + + val MAX_EXECUTORS_KEY = "spark.streaming.dynamicAllocation.maxExecutors" + + def isDynamicAllocationEnabled(conf: SparkConf): Boolean = { + val numExecutor = conf.getInt("spark.executor.instances", 0) + val streamingDynamicAllocationEnabled = conf.getBoolean(ENABLED_KEY, false) + if (numExecutor != 0 && streamingDynamicAllocationEnabled) { + throw new IllegalArgumentException( + "Dynamic Allocation for streaming cannot be enabled while spark.executor.instances is set.") + } + if (Utils.isDynamicAllocationEnabled(conf) && streamingDynamicAllocationEnabled) { + throw new IllegalArgumentException( + """ + |Dynamic Allocation cannot be enabled for both streaming and core at the same time. + |Please disable core Dynamic Allocation by setting spark.dynamicAllocation.enabled to + |false to use Dynamic Allocation in streaming. + """.stripMargin) + } + val testing = conf.getBoolean("spark.streaming.dynamicAllocation.testing", false) + numExecutor == 0 && streamingDynamicAllocationEnabled && (!Utils.isLocalMaster(conf) || testing) + } + + def createIfEnabled( + client: ExecutorAllocationClient, + receiverTracker: ReceiverTracker, + conf: SparkConf, + batchDurationMs: Long, + clock: Clock): Option[ExecutorAllocationManager] = { + if (isDynamicAllocationEnabled(conf)) { + Some(new ExecutorAllocationManager(client, receiverTracker, conf, batchDurationMs, clock)) + } else None + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 61f9e0974c..303c325274 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -57,6 +57,8 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // A tracker to track all the input stream information as well as processed record number var inputInfoTracker: InputInfoTracker = null + private var executorAllocationManager: Option[ExecutorAllocationManager] = None + private var eventLoop: EventLoop[JobSchedulerEvent] = null def start(): Unit = synchronized { @@ -79,8 +81,16 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { listenerBus.start() receiverTracker = new ReceiverTracker(ssc) inputInfoTracker = new InputInfoTracker(ssc) + executorAllocationManager = ExecutorAllocationManager.createIfEnabled( + ssc.sparkContext, + receiverTracker, + ssc.conf, + ssc.graph.batchDuration.milliseconds, + clock) + executorAllocationManager.foreach(ssc.addStreamingListener) receiverTracker.start() jobGenerator.start() + executorAllocationManager.foreach(_.start()) logInfo("Started JobScheduler") } @@ -93,6 +103,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { receiverTracker.stop(processAllReceivedData) } + if (executorAllocationManager 
!= null) { + executorAllocationManager.foreach(_.stop()) + } + // Second, stop generating jobs. If it has to process all received data, // then this will wait for all the processing through JobScheduler to be over. jobGenerator.stop(processAllReceivedData) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index b3ae287001..d67f70732d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -92,6 +92,8 @@ private[streaming] case object AllReceiverIds extends ReceiverTrackerLocalMessag private[streaming] case class UpdateReceiverRateLimit(streamUID: Int, newRate: Long) extends ReceiverTrackerLocalMessage +private[streaming] case object GetAllReceiverInfo extends ReceiverTrackerLocalMessage + /** * This class manages the execution of the receivers of ReceiverInputDStreams. Instance of * this class must be created after all input streams have been added and StreamingContext.start() @@ -234,6 +236,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } + /** + * Get the executors allocated to each receiver. + * @return a map containing receiver ids to optional executor ids. + */ + def allocatedExecutors(): Map[Int, Option[String]] = { + endpoint.askWithRetry[Map[Int, ReceiverTrackingInfo]](GetAllReceiverInfo).mapValues { + _.runningExecutor.map { _.executorId } + } + } + + def numReceivers(): Int = { + receiverInputStreams.size + } + /** Register a receiver */ private def registerReceiver( streamId: Int, @@ -506,9 +522,12 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false case DeregisterReceiver(streamId, message, error) => deregisterReceiver(streamId, message, error) context.reply(true) + // Local messages case AllReceiverIds => context.reply(receiverTrackingInfos.filter(_._2.state != ReceiverState.INACTIVE).keys.toSeq) + case GetAllReceiverInfo => + context.reply(receiverTrackingInfos.toMap) case StopAllReceivers => assert(isTrackerStopping || isTrackerStopped) stopReceivers() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala new file mode 100644 index 0000000000..7630f4a75e --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ExecutorAllocationManagerSuite.scala @@ -0,0 +1,395 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import org.mockito.Matchers.{eq => meq} +import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, PrivateMethodTester} +import org.scalatest.concurrent.Eventually.{eventually, timeout} +import org.scalatest.mock.MockitoSugar +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.{ExecutorAllocationClient, SparkConf, SparkFunSuite} +import org.apache.spark.streaming.{DummyInputDStream, Seconds, StreamingContext} +import org.apache.spark.util.{ManualClock, Utils} + + +class ExecutorAllocationManagerSuite extends SparkFunSuite + with BeforeAndAfter with BeforeAndAfterAll with MockitoSugar with PrivateMethodTester { + + import ExecutorAllocationManager._ + + private val batchDurationMillis = 1000L + private var allocationClient: ExecutorAllocationClient = null + private var clock: ManualClock = null + + before { + allocationClient = mock[ExecutorAllocationClient] + clock = new ManualClock() + } + + test("basic functionality") { + // Test that adding batch processing time info to allocation manager + // causes executors to be requested and killed accordingly + + // There is 1 receiver, and exec 1 has been allocated to it + withAllocationManager(numReceivers = 1) { case (receiverTracker, allocationManager) => + when(receiverTracker.allocatedExecutors).thenReturn(Map(1 -> Some("1"))) + + /** Add data point for batch processing time and verify executor allocation */ + def addBatchProcTimeAndVerifyAllocation(batchProcTimeMs: Double)(body: => Unit): Unit = { + // 2 active executors + reset(allocationClient) + when(allocationClient.getExecutorIds()).thenReturn(Seq("1", "2")) + addBatchProcTime(allocationManager, batchProcTimeMs.toLong) + clock.advance(SCALING_INTERVAL_DEFAULT_SECS * 1000 + 1) + eventually(timeout(10 seconds)) { + body + } + } + + /** Verify that the expected number of total executor were requested */ + def verifyTotalRequestedExecs(expectedRequestedTotalExecs: Option[Int]): Unit = { + if (expectedRequestedTotalExecs.nonEmpty) { + require(expectedRequestedTotalExecs.get > 0) + verify(allocationClient, times(1)).requestTotalExecutors( + meq(expectedRequestedTotalExecs.get), meq(0), meq(Map.empty)) + } else { + verify(allocationClient, never).requestTotalExecutors(0, 0, Map.empty) + } + } + + /** Verify that a particular executor was killed */ + def verifyKilledExec(expectedKilledExec: Option[String]): Unit = { + if (expectedKilledExec.nonEmpty) { + verify(allocationClient, times(1)).killExecutor(meq(expectedKilledExec.get)) + } else { + verify(allocationClient, never).killExecutor(null) + } + } + + // Batch proc time = batch interval, should increase allocation by 1 + addBatchProcTimeAndVerifyAllocation(batchDurationMillis) { + verifyTotalRequestedExecs(Some(3)) // one already allocated, increase allocation by 1 + verifyKilledExec(None) + } + + // Batch proc time = batch interval * 2, should increase allocation by 2 + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * 2) { + verifyTotalRequestedExecs(Some(4)) + verifyKilledExec(None) + } + + // Batch proc time slightly more than the scale up ratio, should increase allocation by 1 + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_UP_RATIO_DEFAULT + 1) { + verifyTotalRequestedExecs(Some(3)) + verifyKilledExec(None) + } + + // Batch proc time slightly less than the scale up ratio, should not change allocation + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_UP_RATIO_DEFAULT - 1) { + 
verifyTotalRequestedExecs(None) + verifyKilledExec(None) + } + + // Batch proc time slightly more than the scale down ratio, should not change allocation + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_DOWN_RATIO_DEFAULT + 1) { + verifyTotalRequestedExecs(None) + verifyKilledExec(None) + } + + // Batch proc time slightly more than the scale down ratio, should not change allocation + addBatchProcTimeAndVerifyAllocation(batchDurationMillis * SCALING_DOWN_RATIO_DEFAULT - 1) { + verifyTotalRequestedExecs(None) + verifyKilledExec(Some("2")) + } + } + } + + test("requestExecutors policy") { + + /** Verify that the expected number of total executor were requested */ + def verifyRequestedExecs( + numExecs: Int, + numNewExecs: Int, + expectedRequestedTotalExecs: Int)( + implicit allocationManager: ExecutorAllocationManager): Unit = { + reset(allocationClient) + when(allocationClient.getExecutorIds()).thenReturn((1 to numExecs).map(_.toString)) + requestExecutors(allocationManager, numNewExecs) + verify(allocationClient, times(1)).requestTotalExecutors( + meq(expectedRequestedTotalExecs), meq(0), meq(Map.empty)) + } + + withAllocationManager(numReceivers = 1) { case (_, allocationManager) => + implicit val am = allocationManager + intercept[IllegalArgumentException] { + verifyRequestedExecs(numExecs = 0, numNewExecs = 0, 0) + } + verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 1) + verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4) + } + + withAllocationManager(numReceivers = 2) { case(_, allocationManager) => + implicit val am = allocationManager + + verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4) + } + + withAllocationManager( + // Test min 2 executors + new SparkConf().set("spark.streaming.dynamicAllocation.minExecutors", "2")) { + case (_, allocationManager) => + implicit val am = allocationManager + + verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 0, numNewExecs = 3, expectedRequestedTotalExecs = 3) + verifyRequestedExecs(numExecs = 1, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 1, numNewExecs = 2, expectedRequestedTotalExecs = 3) + verifyRequestedExecs(numExecs = 2, numNewExecs = 1, expectedRequestedTotalExecs = 3) + verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 4) + } + + withAllocationManager( + // Test with max 2 executors + new SparkConf().set("spark.streaming.dynamicAllocation.maxExecutors", "2")) { + case (_, allocationManager) => + implicit val am = allocationManager + + verifyRequestedExecs(numExecs = 0, numNewExecs = 1, expectedRequestedTotalExecs = 1) + verifyRequestedExecs(numExecs = 0, numNewExecs = 3, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 1, numNewExecs = 2, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 2, numNewExecs = 1, expectedRequestedTotalExecs = 2) + verifyRequestedExecs(numExecs = 2, numNewExecs = 2, expectedRequestedTotalExecs = 2) + } + } + + test("killExecutor policy") { + + /** + * Verify that a particular executor was killed, given active executors and executors + * allocated to receivers. 
+ */ + def verifyKilledExec( + execIds: Seq[String], + receiverExecIds: Map[Int, Option[String]], + expectedKilledExec: Option[String])( + implicit x: (ReceiverTracker, ExecutorAllocationManager)): Unit = { + val (receiverTracker, allocationManager) = x + + reset(allocationClient) + when(allocationClient.getExecutorIds()).thenReturn(execIds) + when(receiverTracker.allocatedExecutors).thenReturn(receiverExecIds) + killExecutor(allocationManager) + if (expectedKilledExec.nonEmpty) { + verify(allocationClient, times(1)).killExecutor(meq(expectedKilledExec.get)) + } else { + verify(allocationClient, never).killExecutor(null) + } + } + + withAllocationManager() { case (receiverTracker, allocationManager) => + implicit val rcvrTrackerAndExecAllocMgr = (receiverTracker, allocationManager) + + verifyKilledExec(Nil, Map.empty, None) + verifyKilledExec(Seq("1", "2"), Map.empty, None) + verifyKilledExec(Seq("1"), Map(1 -> Some("1")), None) + verifyKilledExec(Seq("1", "2"), Map(1 -> Some("1")), Some("2")) + verifyKilledExec(Seq("1", "2"), Map(1 -> Some("1"), 2 -> Some("2")), None) + } + + withAllocationManager( + new SparkConf().set("spark.streaming.dynamicAllocation.minExecutors", "2")) { + case (receiverTracker, allocationManager) => + implicit val rcvrTrackerAndExecAllocMgr = (receiverTracker, allocationManager) + + verifyKilledExec(Seq("1", "2"), Map.empty, None) + verifyKilledExec(Seq("1", "2", "3"), Map(1 -> Some("1"), 2 -> Some("2")), Some("3")) + } + } + + test("parameter validation") { + + def validateParams( + numReceivers: Int = 1, + scalingIntervalSecs: Option[Int] = None, + scalingUpRatio: Option[Double] = None, + scalingDownRatio: Option[Double] = None, + minExecs: Option[Int] = None, + maxExecs: Option[Int] = None): Unit = { + require(numReceivers > 0) + val receiverTracker = mock[ReceiverTracker] + when(receiverTracker.numReceivers()).thenReturn(numReceivers) + val conf = new SparkConf() + if (scalingIntervalSecs.nonEmpty) { + conf.set( + "spark.streaming.dynamicAllocation.scalingInterval", + s"${scalingIntervalSecs.get}s") + } + if (scalingUpRatio.nonEmpty) { + conf.set("spark.streaming.dynamicAllocation.scalingUpRatio", scalingUpRatio.get.toString) + } + if (scalingDownRatio.nonEmpty) { + conf.set( + "spark.streaming.dynamicAllocation.scalingDownRatio", + scalingDownRatio.get.toString) + } + if (minExecs.nonEmpty) { + conf.set("spark.streaming.dynamicAllocation.minExecutors", minExecs.get.toString) + } + if (maxExecs.nonEmpty) { + conf.set("spark.streaming.dynamicAllocation.maxExecutors", maxExecs.get.toString) + } + new ExecutorAllocationManager( + allocationClient, receiverTracker, conf, batchDurationMillis, clock) + } + + validateParams(numReceivers = 1) + validateParams(numReceivers = 2, minExecs = Some(1)) + validateParams(numReceivers = 2, minExecs = Some(3)) + validateParams(numReceivers = 2, maxExecs = Some(3)) + validateParams(numReceivers = 2, maxExecs = Some(1)) + validateParams(minExecs = Some(3), maxExecs = Some(3)) + validateParams(scalingIntervalSecs = Some(1)) + validateParams(scalingUpRatio = Some(1.1)) + validateParams(scalingDownRatio = Some(0.1)) + validateParams(scalingUpRatio = Some(1.1), scalingDownRatio = Some(0.1)) + + intercept[IllegalArgumentException] { + validateParams(minExecs = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(minExecs = Some(-1)) + } + intercept[IllegalArgumentException] { + validateParams(maxExecs = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(maxExecs = Some(-1)) + } + 
intercept[IllegalArgumentException] { + validateParams(minExecs = Some(4), maxExecs = Some(3)) + } + intercept[IllegalArgumentException] { + validateParams(scalingIntervalSecs = Some(-1)) + } + intercept[IllegalArgumentException] { + validateParams(scalingIntervalSecs = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(scalingUpRatio = Some(-0.1)) + } + intercept[IllegalArgumentException] { + validateParams(scalingUpRatio = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(scalingDownRatio = Some(-0.1)) + } + intercept[IllegalArgumentException] { + validateParams(scalingDownRatio = Some(0)) + } + intercept[IllegalArgumentException] { + validateParams(scalingUpRatio = Some(0.5), scalingDownRatio = Some(0.5)) + } + intercept[IllegalArgumentException] { + validateParams(scalingUpRatio = Some(0.3), scalingDownRatio = Some(0.5)) + } + } + + test("enabling and disabling") { + withStreamingContext(new SparkConf()) { ssc => + ssc.start() + assert(getExecutorAllocationManager(ssc).isEmpty) + } + + withStreamingContext( + new SparkConf().set("spark.streaming.dynamicAllocation.enabled", "true")) { ssc => + ssc.start() + assert(getExecutorAllocationManager(ssc).nonEmpty) + } + + val confWithBothDynamicAllocationEnabled = new SparkConf() + .set("spark.streaming.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.enabled", "true") + .set("spark.dynamicAllocation.testing", "true") + require(Utils.isDynamicAllocationEnabled(confWithBothDynamicAllocationEnabled) === true) + withStreamingContext(confWithBothDynamicAllocationEnabled) { ssc => + intercept[IllegalArgumentException] { + ssc.start() + } + } + } + + private def withAllocationManager( + conf: SparkConf = new SparkConf, + numReceivers: Int = 1 + )(body: (ReceiverTracker, ExecutorAllocationManager) => Unit): Unit = { + + val receiverTracker = mock[ReceiverTracker] + when(receiverTracker.numReceivers()).thenReturn(numReceivers) + + val manager = new ExecutorAllocationManager( + allocationClient, receiverTracker, conf, batchDurationMillis, clock) + try { + manager.start() + body(receiverTracker, manager) + } finally { + manager.stop() + } + } + + private val _addBatchProcTime = PrivateMethod[Unit]('addBatchProcTime) + private val _requestExecutors = PrivateMethod[Unit]('requestExecutors) + private val _killExecutor = PrivateMethod[Unit]('killExecutor) + private val _executorAllocationManager = + PrivateMethod[Option[ExecutorAllocationManager]]('executorAllocationManager) + + private def addBatchProcTime(manager: ExecutorAllocationManager, timeMs: Long): Unit = { + manager invokePrivate _addBatchProcTime(timeMs) + } + + private def requestExecutors(manager: ExecutorAllocationManager, newExecs: Int): Unit = { + manager invokePrivate _requestExecutors(newExecs) + } + + private def killExecutor(manager: ExecutorAllocationManager): Unit = { + manager invokePrivate _killExecutor() + } + + private def getExecutorAllocationManager( + ssc: StreamingContext): Option[ExecutorAllocationManager] = { + ssc.scheduler invokePrivate _executorAllocationManager() + } + + private def withStreamingContext(conf: SparkConf)(body: StreamingContext => Unit): Unit = { + conf.setMaster("local").setAppName(this.getClass.getSimpleName).set( + "spark.streaming.dynamicAllocation.testing", "true") // to test dynamic allocation + + var ssc: StreamingContext = null + try { + ssc = new StreamingContext(conf, Seconds(1)) + new DummyInputDStream(ssc).foreachRDD(_ => { }) + body(ssc) + } finally { + if (ssc != null) ssc.stop() + } + 
} +} -- cgit v1.2.3 From 457e58befe8cb7c346e54b344a45fa357b68cfc0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 6 Apr 2016 16:00:29 -0700 Subject: [SPARK-14424][BUILD][DOCS] Update the build docs to switch from assembly to package and add a no… MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Change our build docs & shell scripts to that developers are aware of the change from "assembly" to "package" ## How was this patch tested? Manually ran ./bin/spark-shell after ./build/sbt assembly and verified error message printed, ran new suggested build target and verified ./bin/spark-shell runs after this. Author: Holden Karau Author: Holden Karau Closes #12197 from holdenk/SPARK-1424-spark-class-broken-fix-build-docs. --- bin/spark-class | 2 +- docs/building-spark.md | 13 +++---------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/bin/spark-class b/bin/spark-class index b489591778..b2a36b9846 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -44,7 +44,7 @@ fi if [ ! -d "$SPARK_JARS_DIR" ] && [ -z "$SPARK_TESTING$SPARK_SQL_TESTING" ]; then echo "Failed to find Spark jars directory ($SPARK_JARS_DIR)." 1>&2 - echo "You need to build Spark before running this program." 1>&2 + echo "You need to build Spark with the target \"package\" before running this program." 1>&2 exit 1 else LAUNCH_CLASSPATH="$SPARK_JARS_DIR/*" diff --git a/docs/building-spark.md b/docs/building-spark.md index 13aa80496e..40661604af 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -190,13 +190,6 @@ or Java 8 tests are automatically enabled when a Java 8 JDK is detected. If you have JDK 8 installed but it is not the system default, you can set JAVA_HOME to point to JDK 8 before running the tests. -# Building for PySpark on YARN - -PySpark on YARN is only supported if the jar is built with Maven. Further, there is a known problem -with building this assembly jar on Red Hat based operating systems (see [SPARK-1753](https://issues.apache.org/jira/browse/SPARK-1753)). If you wish to -run PySpark on a YARN cluster with Red Hat installed, we recommend that you build the jar elsewhere, -then ship it over to the cluster. We are investigating the exact cause for this. - # Packaging without Hadoop Dependencies for YARN The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. @@ -210,7 +203,7 @@ compilation. More advanced developers may wish to use SBT. The SBT build is derived from the Maven POM files, and so the same Maven profiles and variables can be set to control the SBT build. For example: - build/sbt -Pyarn -Phadoop-2.3 assembly + build/sbt -Pyarn -Phadoop-2.3 package To avoid the overhead of launching sbt each time you need to re-compile, you can launch sbt in interactive mode by running `build/sbt`, and then run all build commands at the command @@ -219,9 +212,9 @@ prompt. For more recommendations on reducing build time, refer to the # Testing with SBT -Some of the tests require Spark to be packaged first, so always run `build/sbt assembly` the first time. 
The following is an example of a correct (build, test) sequence: +Some of the tests require Spark to be packaged first, so always run `build/sbt package` the first time. The following is an example of a correct (build, test) sequence: - build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly + build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver package build/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test To run only a specific test suite as follows: -- cgit v1.2.3 From d717ae1fd74d125a9df21350a70e7c2b2d2b4786 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 6 Apr 2016 16:02:55 -0700 Subject: [SPARK-14444][BUILD] Add a new scalastyle `NoScalaDoc` to prevent ScalaDoc-style multiline comments ## What changes were proposed in this pull request? According to the [Spark Code Style Guide](https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide#SparkCodeStyleGuide-Indentation), this PR adds a new scalastyle rule to prevent the followings. ``` /** In Spark, we don't use the ScalaDoc style so this * is not correct. */ ``` ## How was this patch tested? Pass the Jenkins tests (including `lint-scala`). Author: Dongjoon Hyun Closes #12221 from dongjoon-hyun/SPARK-14444. --- core/src/main/scala/org/apache/spark/SparkConf.scala | 6 ++++-- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 12 ++++++------ .../scala/org/apache/spark/partial/BoundedDouble.scala | 4 ++-- .../apache/spark/examples/DriverSubmissionTest.scala | 6 ++++-- .../spark/streaming/flume/FlumeInputDStream.scala | 6 ++++-- .../mllib/stat/distribution/MultivariateGaussian.scala | 10 ++++++---- scalastyle-config.xml | 5 +++++ .../apache/spark/sql/catalyst/ScalaReflection.scala | 18 +++++++++--------- .../catalyst/expressions/codegen/CodeGenerator.scala | 8 ++++---- .../apache/spark/sql/catalyst/plans/QueryPlan.scala | 6 ++++-- .../apache/spark/sql/RelationalGroupedDataset.scala | 4 ++-- .../apache/spark/sql/execution/WholeStageCodegen.scala | 4 ++-- .../apache/spark/sql/execution/ui/SparkPlanGraph.scala | 4 ++-- .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 10 +++++----- 14 files changed, 59 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 5da2e98f1f..e0fd248c43 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -419,8 +419,10 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { */ private[spark] def getenv(name: String): String = System.getenv(name) - /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not - * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ + /** + * Checks for illegal or deprecated config settings. Throws an exception for the former. Not + * idempotent - may mutate this conf object to convert deprecated settings to supported ones. 
+ */ private[spark] def validateSettings() { if (contains("spark.local.dir")) { val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 4e8e363635..41ac308808 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -76,9 +76,9 @@ class SparkHadoopUtil extends Logging { /** - * Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop - * configuration. - */ + * Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop + * configuration. + */ def appendS3AndSparkHadoopConfigurations(conf: SparkConf, hadoopConf: Configuration): Unit = { // Note: this null check is around more than just access to the "conf" object to maintain // the behavior of the old implementation of this code, for backwards compatibility. @@ -108,9 +108,9 @@ class SparkHadoopUtil extends Logging { } /** - * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop - * subsystems. - */ + * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop + * subsystems. + */ def newConfiguration(conf: SparkConf): Configuration = { val hadoopConf = new Configuration() appendS3AndSparkHadoopConfigurations(conf, hadoopConf) diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index d06b2c67d2..c562c70aba 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,8 +28,8 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false - */ + * Note that consistent with Double, any NaN value will make equality false + */ override def equals(that: Any): Boolean = that match { case that: BoundedDouble => { diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index a2d59a1c95..d12ef642bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -22,8 +22,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.util.Utils -/** Prints out environmental information, sleeps, and then exits. Made to - * test driver submission in the standalone scheduler. */ +/** + * Prints out environmental information, sleeps, and then exits. Made to + * test driver submission in the standalone scheduler. 
+ */ object DriverSubmissionTest { def main(args: Array[String]) { if (args.length < 1) { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 6e7c3f358e..13aa817492 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -130,8 +130,10 @@ class FlumeEventServer(receiver: FlumeReceiver) extends AvroSourceProtocol { } } -/** A NetworkReceiver which listens for events using the - * Flume Avro interface. */ +/** + * A NetworkReceiver which listens for events using the + * Flume Avro interface. + */ private[streaming] class FlumeReceiver( host: String, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 052b5b1d65..6c6e9fb7c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -61,15 +61,17 @@ class MultivariateGaussian @Since("1.3.0") ( */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - /** Returns density of this multivariate Gaussian at given point, x - */ + /** + * Returns density of this multivariate Gaussian at given point, x + */ @Since("1.3.0") def pdf(x: Vector): Double = { pdf(x.toBreeze) } - /** Returns the log-density of this multivariate Gaussian at given point, x - */ + /** + * Returns the log-density of this multivariate Gaussian at given point, x + */ @Since("1.3.0") def logpdf(x: Vector): Double = { logpdf(x.toBreeze) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 33c2cbd293..472a8f4084 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -223,6 +223,11 @@ This file is divided into 3 sections: ]]> + + (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] + Use Javadoc style indentation for multiline comments + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d241b8a79b..4795fc2557 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -762,15 +762,15 @@ trait ScalaReflection { } /** - * Returns the full class name for a type. The returned name is the canonical - * Scala name, where each component is separated by a period. It is NOT the - * Java-equivalent runtime name (no dollar signs). - * - * In simple cases, both the Scala and Java names are the same, however when Scala - * generates constructs that do not map to a Java equivalent, such as singleton objects - * or nested classes in package objects, it uses the dollar sign ($) to create - * synthetic classes, emulating behaviour in Java bytecode. - */ + * Returns the full class name for a type. The returned name is the canonical + * Scala name, where each component is separated by a period. It is NOT the + * Java-equivalent runtime name (no dollar signs). 
+ * + * In simple cases, both the Scala and Java names are the same, however when Scala + * generates constructs that do not map to a Java equivalent, such as singleton objects + * or nested classes in package objects, it uses the dollar sign ($) to create + * synthetic classes, emulating behaviour in Java bytecode. + */ def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1bebd4e904..ee7f4fadca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -626,15 +626,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { /** - * Compile the Java source code into a Java class, using Janino. - */ + * Compile the Java source code into a Java class, using Janino. + */ def compile(code: String): GeneratedClass = { cache.get(code) } /** - * Compile the Java source code into a Java class, using Janino. - */ + * Compile the Java source code into a Java class, using Janino. + */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 609a33e2f1..0a11574f44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -211,8 +211,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } - /** Returns the result of running [[transformExpressions]] on this node - * and all its children. */ + /** + * Returns the result of running [[transformExpressions]] on this node + * and all its children. 
+ */ def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { transform { case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 91c02053ae..7dbf2e6c7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -408,7 +408,7 @@ private[sql] object RelationalGroupedDataset { private[sql] object RollupType extends GroupType /** - * To indicate it's the PIVOT - */ + * To indicate it's the PIVOT + */ private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 98129d6c52..c4594f0480 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -312,8 +312,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup } /** Codegened pipeline for: - * ${toCommentSafeString(child.treeString.trim)} - */ + * ${toCommentSafeString(child.treeString.trim)} + */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { private Object[] references; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 012b125d6b..c6fcb6956c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -167,8 +167,8 @@ private[ui] class SparkPlanGraphNode( } /** - * Represent a tree of SparkPlan for WholeStageCodegen. - */ + * Represent a tree of SparkPlan for WholeStageCodegen. + */ private[ui] class SparkPlanGraphCluster( id: Long, name: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index cfe4911cb7..948106fd06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -100,11 +100,11 @@ abstract class JdbcDialect extends Serializable { } /** - * Override connection specific properties to run before a select is made. This is in place to - * allow dialects that need special treatment to optimize behavior. - * @param connection The connection object - * @param properties The connection properties. This is passed through from the relation. - */ + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param properties The connection properties. This is passed through from the relation. 
+ */ def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { } -- cgit v1.2.3 From c4bb02abf2c5b1724f2f848c79da5ebbf2584e45 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Wed, 6 Apr 2016 16:11:59 -0700 Subject: [SPARK-14290][CORE][NETWORK] avoid significant memory copy in netty's transferTo ## What changes were proposed in this pull request? When netty transfer data that is not `FileRegion`, data will be in format of `ByteBuf`, If the data is large, there will occur significant performance issue because there is memory copy underlying in `sun.nio.ch.IOUtil.write`, the CPU is 100% used, and network is very low. In this PR, if data size is large, we will split it into small chunks to call `WritableByteChannel.write()`, so that avoid wasting of memory copy. Because the data can't be written within a single write, and it will call `transferTo` multiple times. ## How was this patch tested? Spark unit test and manual test. Manual test: `sc.parallelize(Array(1,2,3),3).mapPartitions(a=>Array(new Array[Double](1024 * 1024 * 50)).iterator).reduce((a,b)=> a).length` For more details, please refer to [SPARK-14290](https://issues.apache.org/jira/browse/SPARK-14290) Author: Zhang, Liye Closes #12083 from liyezhang556520/spark-14290. --- .../spark/network/protocol/MessageWithHeader.java | 30 +++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java index 66227f96a1..4f8781b42a 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageWithHeader.java @@ -18,6 +18,7 @@ package org.apache.spark.network.protocol; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.channels.WritableByteChannel; import javax.annotation.Nullable; @@ -43,6 +44,14 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { private final long bodyLength; private long totalBytesTransferred; + /** + * When the write buffer size is larger than this limit, I/O will be done in chunks of this size. + * The size should not be too large as it will waste underlying memory copy. e.g. If network + * avaliable buffer is smaller than this limit, the data cannot be sent within one single write + * operation while it still will make memory copy with this size. + */ + private static final int NIO_BUFFER_LIMIT = 256 * 1024; + /** * Construct a new MessageWithHeader. * @@ -128,8 +137,27 @@ class MessageWithHeader extends AbstractReferenceCounted implements FileRegion { } private int copyByteBuf(ByteBuf buf, WritableByteChannel target) throws IOException { - int written = target.write(buf.nioBuffer()); + ByteBuffer buffer = buf.nioBuffer(); + int written = (buffer.remaining() <= NIO_BUFFER_LIMIT) ? 
+ target.write(buffer) : writeNioBuffer(target, buffer); buf.skipBytes(written); return written; } + + private int writeNioBuffer( + WritableByteChannel writeCh, + ByteBuffer buf) throws IOException { + int originalLimit = buf.limit(); + int ret = 0; + + try { + int ioSize = Math.min(buf.remaining(), NIO_BUFFER_LIMIT); + buf.limit(buf.position() + ioSize); + ret = writeCh.write(buf); + } finally { + buf.limit(originalLimit); + } + + return ret; + } } -- cgit v1.2.3 From f1def573f4c1c757f727476521a1509b5285051d Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 6 Apr 2016 16:18:04 -0700 Subject: [SPARK-13112][CORE] Make sure RegisterExecutorResponse arrive before LaunchTask ## What changes were proposed in this pull request? Send `RegisterExecutorResponse` using `executorRef` in order to make sure RegisterExecutorResponse and LaunchTask are both sent using the same channel. Then RegisterExecutorResponse will always arrive before LaunchTask ## How was this patch tested? Existing unit tests Closes #12078 Author: Shixiong Zhu Closes #12211 from zsxwing/SPARK-13112. --- .../apache/spark/executor/CoarseGrainedExecutorBackend.scala | 7 +++---- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 6 ++++-- .../test/scala/org/apache/spark/HeartbeatReceiverSuite.scala | 11 ++++++++--- .../spark/deploy/StandaloneDynamicAllocationSuite.scala | 2 +- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 81e41e6fa7..d4ed5845e7 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -57,12 +57,11 @@ private[spark] class CoarseGrainedExecutorBackend( rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) - ref.ask[RegisterExecutorResponse](RegisterExecutor(executorId, self, cores, extractLogUrls)) + ref.ask[Boolean](RegisterExecutor(executorId, self, cores, extractLogUrls)) }(ThreadUtils.sameThread).onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" - case Success(msg) => Utils.tryLogNonFatalError { - Option(self).foreach(_.send(msg)) // msg must be RegisterExecutorResponse - } + case Success(msg) => + // Always receive `true`. 
Just ignore it case Failure(e) => { logError(s"Cannot register with driver: $driverUrl", e) System.exit(1) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index e5abf0e150..8896391f97 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -150,7 +150,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp case RegisterExecutor(executorId, executorRef, cores, logUrls) => if (executorDataMap.contains(executorId)) { - context.reply(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) + context.reply(true) } else { // If the executor's rpc env is not listening for incoming connections, `hostPort` // will be null, and the client connection should be used to contact the executor. @@ -177,8 +178,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } + executorRef.send(RegisteredExecutor(executorAddress.host)) // Note: some tests expect the reply to come after we put the executor in the map - context.reply(RegisteredExecutor(executorAddress.host)) + context.reply(true) listenerBus.post( SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 3777d77f8f..713d5e58b4 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -174,9 +174,9 @@ class HeartbeatReceiverSuite val dummyExecutorEndpoint2 = new FakeExecutorEndpoint(rpcEnv) val dummyExecutorEndpointRef1 = rpcEnv.setupEndpoint("fake-executor-1", dummyExecutorEndpoint1) val dummyExecutorEndpointRef2 = rpcEnv.setupEndpoint("fake-executor-2", dummyExecutorEndpoint2) - fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( + fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean]( RegisterExecutor(executorId1, dummyExecutorEndpointRef1, 0, Map.empty)) - fakeSchedulerBackend.driverEndpoint.askWithRetry[RegisterExecutorResponse]( + fakeSchedulerBackend.driverEndpoint.askWithRetry[Boolean]( RegisterExecutor(executorId2, dummyExecutorEndpointRef2, 0, Map.empty)) heartbeatReceiverRef.askWithRetry[Boolean](TaskSchedulerIsSet) addExecutorAndVerify(executorId1) @@ -255,7 +255,12 @@ class HeartbeatReceiverSuite /** * Dummy RPC endpoint to simulate executors. */ -private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint +private class FakeExecutorEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint { + + override def receive: PartialFunction[Any, Unit] = { + case _ => + } +} /** * Dummy scheduler backend to simulate executor allocation requests to the cluster manager. 
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index d2e24912b5..3d39bd4a74 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -561,7 +561,7 @@ class StandaloneDynamicAllocationSuite when(endpointRef.address).thenReturn(mockAddress) val message = RegisterExecutor(id, endpointRef, 10, Map.empty) val backend = sc.schedulerBackend.asInstanceOf[CoarseGrainedSchedulerBackend] - backend.driverEndpoint.askWithRetry[CoarseGrainedClusterMessage](message) + backend.driverEndpoint.askWithRetry[Boolean](message) } } -- cgit v1.2.3 From 864d1b4d665e2cc1d40b53502a4ddf26c1fbfc1d Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 6 Apr 2016 16:50:59 -0700 Subject: [SPARK-14436][SQL] Make JavaDatasetAggregatorSuiteBase public. Without this, unit tests that extend that class fail for me locally on maven, because JUnit tries to run methods in that class and gets an IllegalAccessError. Author: Marcelo Vanzin Closes #12212 from vanzin/SPARK-14436. --- .../sql/sources/JavaDatasetAggregatorSuite.java | 55 +-------------- .../sources/JavaDatasetAggregatorSuiteBase.java | 81 ++++++++++++++++++++++ 2 files changed, 83 insertions(+), 53 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java index 594f4675bd..8cb174b906 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuite.java @@ -17,26 +17,19 @@ package test.org.apache.spark.sql.sources; -import java.io.Serializable; import java.util.Arrays; -import java.util.List; -import org.junit.After; +import scala.Tuple2; + import org.junit.Assert; -import org.junit.Before; import org.junit.Test; -import scala.Tuple2; -import org.apache.spark.SparkContext; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.MapFunction; import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.KeyValueGroupedDataset; import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.expressions.java.typed; -import org.apache.spark.sql.test.TestSQLContext; /** * Suite for testing the aggregate functionality of Datasets in Java. @@ -130,47 +123,3 @@ public class JavaDatasetAggregatorSuite extends JavaDatasetAggregatorSuiteBase { Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList()); } } - -/** - * Common test base shared across this and Java8DatasetAggregatorSuite. 
- */ -class JavaDatasetAggregatorSuiteBase implements Serializable { - protected transient JavaSparkContext jsc; - protected transient TestSQLContext context; - - @Before - public void setUp() { - // Trigger static initializer of TestData - SparkContext sc = new SparkContext("local[*]", "testing"); - jsc = new JavaSparkContext(sc); - context = new TestSQLContext(sc); - context.loadTestData(); - } - - @After - public void tearDown() { - context.sparkContext().stop(); - context = null; - jsc = null; - } - - protected Tuple2 tuple2(T1 t1, T2 t2) { - return new Tuple2<>(t1, t2); - } - - protected KeyValueGroupedDataset> generateGroupedDataset() { - Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); - List> data = - Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); - Dataset> ds = context.createDataset(data, encoder); - - return ds.groupByKey( - new MapFunction, String>() { - @Override - public String call(Tuple2 value) throws Exception { - return value._1(); - } - }, - Encoders.STRING()); - } -} diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java new file mode 100644 index 0000000000..7863177093 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaDatasetAggregatorSuiteBase.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.List; + +import scala.Tuple2; + +import org.junit.After; +import org.junit.Before; + +import org.apache.spark.SparkContext; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.KeyValueGroupedDataset; +import org.apache.spark.sql.test.TestSQLContext; + +/** + * Common test base shared across this and Java8DatasetAggregatorSuite. 
+ */ +public class JavaDatasetAggregatorSuiteBase implements Serializable { + protected transient JavaSparkContext jsc; + protected transient TestSQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + SparkContext sc = new SparkContext("local[*]", "testing"); + jsc = new JavaSparkContext(sc); + context = new TestSQLContext(sc); + context.loadTestData(); + } + + @After + public void tearDown() { + context.sparkContext().stop(); + context = null; + jsc = null; + } + + protected Tuple2 tuple2(T1 t1, T2 t2) { + return new Tuple2<>(t1, t2); + } + + protected KeyValueGroupedDataset> generateGroupedDataset() { + Encoder> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT()); + List> data = + Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3)); + Dataset> ds = context.createDataset(data, encoder); + + return ds.groupByKey( + new MapFunction, String>() { + @Override + public String call(Tuple2 value) throws Exception { + return value._1(); + } + }, + Encoders.STRING()); + } +} + -- cgit v1.2.3 From bb873754b4700104755ab969694bf30945557dc3 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 6 Apr 2016 17:13:34 -0700 Subject: [SPARK-12382][ML] Remove mllib GBT implementation and wrap ml ## What changes were proposed in this pull request? This patch removes the implementation of gradient boosted trees in mllib/tree/GradientBoostedTrees.scala and changes mllib GBTs to call the implementation in spark.ML. Primary changes: * Removed `boost` method in mllib GradientBoostedTrees.scala * Created new test suite GradientBoostedTreesSuite in ML, which contains unit tests that were specific to GBT internals from mllib Other changes: * Added an `updatePrediction` method in GradientBoostedTrees package. This method is added to provide consistency for methods that build predictions from boosted models. There are several methods that hard code the method of predicting as: sum_{i=1}^{numTrees} (treePrediction*treeWeight). Calling this function ensures that test methods that check accuracy use the same prediction method that the algorithm uses during training * Added methods that were previously only used in testing, but were public methods, to GradientBoostedTrees. This includes `computeError` (previously part of `Loss` trait) and `evaluateEachIteration`. These are used in the new spark.ML unit tests. They are left in mllib as well so as to not break the API. ## How was this patch tested? Existing unit tests which compare ML and MLlib ensure that mllib GBTs have not changed. Only a single unit test was moved to ML, which verifies that `runWithValidation` performs as expected. Author: sethah Closes #12050 from sethah/SPARK-12382. 
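To spell out the prediction convention referenced in the notes above: a boosted ensemble predicts by accumulating treePrediction * treeWeight over all trees, which is exactly the update that `updatePrediction` applies one iteration at a time. A small self-contained sketch of that fold, with plain doubles standing in for the per-tree predictions (the function name is illustrative, not the patch's API):

    // Ensemble prediction = sum over i of treeWeights(i) * treePredictions(i).
    def ensemblePrediction(treePredictions: Array[Double], treeWeights: Array[Double]): Double = {
      require(treePredictions.length == treeWeights.length, "one weight per tree")
      treePredictions.zip(treeWeights).foldLeft(0.0) { case (acc, (pred, weight)) =>
        acc + pred * weight // the same update updatePrediction performs with a real tree
      }
    }

For example, trees predicting (1.0, -0.5, 0.25) with weights (1.0, 0.1, 0.1) give 1.0 - 0.05 + 0.025 = 0.975.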
--- .../spark/ml/tree/impl/GradientBoostedTrees.scala | 115 ++++++++++++- .../org/apache/spark/ml/tree/treeModels.scala | 4 +- .../spark/mllib/tree/GradientBoostedTrees.scala | 182 +-------------------- .../spark/ml/regression/GBTRegressorSuite.scala | 2 +- .../ml/tree/impl/GradientBoostedTreesSuite.scala | 85 ++++++++++ .../mllib/tree/GradientBoostedTreesSuite.scala | 45 +---- 6 files changed, 207 insertions(+), 226 deletions(-) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 0749d93b7d..d365655674 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.internal.Logging import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} -import org.apache.spark.ml.tree.DecisionTreeModel import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint @@ -30,7 +29,24 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -private[ml] object GradientBoostedTrees extends Logging { + +/** + * A package that implements + * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]] + * for regression and binary classification. + * + * The implementation is based upon: + * J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - When the loss is SquaredError, these methods give the same result, but they could differ + * for other loss functions. + */ +private[spark] object GradientBoostedTrees extends Logging { /** * Method to train a gradient boosting model @@ -107,7 +123,7 @@ private[ml] object GradientBoostedTrees extends Logging { initTree: DecisionTreeRegressionModel, loss: OldLoss): RDD[(Double, Double)] = { data.map { lp => - val pred = initTreeWeight * initTree.rootNode.predictImpl(lp.features).prediction + val pred = updatePrediction(lp.features, 0.0, initTree, initTreeWeight) val error = loss.computeError(pred, lp.label) (pred, error) } @@ -133,7 +149,7 @@ private[ml] object GradientBoostedTrees extends Logging { val newPredError = data.zip(predictionAndError).mapPartitions { iter => iter.map { case (lp, (pred, error)) => - val newPred = pred + tree.rootNode.predictImpl(lp.features).prediction * treeWeight + val newPred = updatePrediction(lp.features, pred, tree, treeWeight) val newError = loss.computeError(newPred, lp.label) (newPred, newError) } @@ -141,6 +157,97 @@ private[ml] object GradientBoostedTrees extends Logging { newPredError } + /** + * Add prediction from a new boosting iteration to an existing prediction. + * + * @param features Vector of features representing a single data point. + * @param prediction The existing prediction. 
+ * @param tree New Decision Tree model. + * @param weight Tree weight. + * @return Updated prediction. + */ + def updatePrediction( + features: Vector, + prediction: Double, + tree: DecisionTreeRegressionModel, + weight: Double): Double = { + prediction + tree.rootNode.predictImpl(features).prediction * weight + } + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param trees Boosted Decision Tree models + * @param treeWeights Learning rates at each boosting iteration. + * @param loss evaluation metric. + * @return Measure of model error on data + */ + def computeError( + data: RDD[LabeledPoint], + trees: Array[DecisionTreeRegressionModel], + treeWeights: Array[Double], + loss: OldLoss): Double = { + data.map { lp => + val predicted = trees.zip(treeWeights).foldLeft(0.0) { case (acc, (model, weight)) => + updatePrediction(lp.features, acc, model, weight) + } + loss.computeError(predicted, lp.label) + }.mean() + } + + /** + * Method to compute error or loss for every iteration of gradient boosting. + * + * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @param trees Boosted Decision Tree models + * @param treeWeights Learning rates at each boosting iteration. + * @param loss evaluation metric. + * @param algo algorithm for the ensemble, either Classification or Regression + * @return an array with index i having the losses or errors for the ensemble + * containing the first i+1 trees + */ + def evaluateEachIteration( + data: RDD[LabeledPoint], + trees: Array[DecisionTreeRegressionModel], + treeWeights: Array[Double], + loss: OldLoss, + algo: OldAlgo.Value): Array[Double] = { + + val sc = data.sparkContext + val remappedData = algo match { + case OldAlgo.Classification => data.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + case _ => data + } + + val numIterations = trees.length + val evaluationArray = Array.fill(numIterations)(0.0) + val localTreeWeights = treeWeights + + var predictionAndError = computeInitialPredictionAndError( + remappedData, localTreeWeights(0), trees(0), loss) + + evaluationArray(0) = predictionAndError.values.mean() + + val broadcastTrees = sc.broadcast(trees) + (1 until numIterations).foreach { nTree => + predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter => + val currentTree = broadcastTrees.value(nTree) + val currentTreeWeight = localTreeWeights(nTree) + iter.map { case (point, (pred, error)) => + val newPred = updatePrediction(point.features, pred, currentTree, currentTreeWeight) + val newError = loss.computeError(newPred, point.label) + (newPred, newError) + } + } + evaluationArray(nTree) = predictionAndError.values.mean() + } + + broadcastTrees.unpersist() + evaluationArray + } + /** * Internal method for performing regression using trees as base learners. 
* @param input training dataset diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index db0ff28d82..c4ab673d9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -133,8 +133,8 @@ private[ml] object TreeEnsembleModel { * following the explanation of Gini importance from "Random Forests" documentation * by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn. * - * For collections of trees, including boosting and bagging, Hastie et al. - * propose to use the average of single tree importances across all trees in the ensemble. + * For collections of trees, including boosting and bagging, Hastie et al. + * propose to use the average of single tree importances across all trees in the ensemble. * * This feature importance is calculated as follows: * - Average over trees: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 0f0c6b466d..7fe60e2d99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,15 +20,11 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging -import org.apache.spark.ml.tree.impl.TimeTracker -import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer +import org.apache.spark.ml.tree.impl.{GradientBoostedTrees => NewGBT} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.impurity.Variance -import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel /** * A class that implements @@ -70,17 +66,8 @@ class GradientBoostedTrees private[spark] ( @Since("1.2.0") def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo - algo match { - case Regression => - GradientBoostedTrees.boost(input, input, boostingStrategy, validate = false, seed) - case Classification => - // Map labels to -1, +1 so binary classification can be treated as regression. - val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, remappedInput, boostingStrategy, validate = false, - seed) - case _ => - throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") - } + val (trees, treeWeights) = NewGBT.run(input, boostingStrategy, seed.toLong) + new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } /** @@ -107,20 +94,9 @@ class GradientBoostedTrees private[spark] ( input: RDD[LabeledPoint], validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo - algo match { - case Regression => - GradientBoostedTrees.boost(input, validationInput, boostingStrategy, validate = true, seed) - case Classification => - // Map labels to -1, +1 so binary classification can be treated as regression. 
- val remappedInput = input.map( - x => new LabeledPoint((x.label * 2) - 1, x.features)) - val remappedValidationInput = validationInput.map( - x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, - validate = true, seed) - case _ => - throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") - } + val (trees, treeWeights) = NewGBT.runWithValidation(input, validationInput, boostingStrategy, + seed.toLong) + new GradientBoostedTreesModel(algo, trees.map(_.toOld), treeWeights) } /** @@ -162,148 +138,4 @@ object GradientBoostedTrees extends Logging { boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { train(input.rdd, boostingStrategy) } - - /** - * Internal method for performing regression using trees as base learners. - * - * @param input Training dataset. - * @param validationInput Validation dataset, ignored if validate is set to false. - * @param boostingStrategy Boosting parameters. - * @param validate Whether or not to use the validation dataset. - * @param seed Random seed. - * @return GradientBoostedTreesModel that can be used for prediction. - */ - private def boost( - input: RDD[LabeledPoint], - validationInput: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy, - validate: Boolean, - seed: Int): GradientBoostedTreesModel = { - val timer = new TimeTracker() - timer.start("total") - timer.start("init") - - boostingStrategy.assertValid() - - // Initialize gradient boosting parameters - val numIterations = boostingStrategy.numIterations - val baseLearners = new Array[DecisionTreeModel](numIterations) - val baseLearnerWeights = new Array[Double](numIterations) - val loss = boostingStrategy.loss - val learningRate = boostingStrategy.learningRate - // Prepare strategy for individual trees, which use regression with variance impurity. - val treeStrategy = boostingStrategy.treeStrategy.copy - val validationTol = boostingStrategy.validationTol - treeStrategy.algo = Regression - treeStrategy.impurity = Variance - treeStrategy.assertValid() - - // Cache input - val persistedInput = if (input.getStorageLevel == StorageLevel.NONE) { - input.persist(StorageLevel.MEMORY_AND_DISK) - true - } else { - false - } - - // Prepare periodic checkpointers - val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( - treeStrategy.getCheckpointInterval, input.sparkContext) - val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)]( - treeStrategy.getCheckpointInterval, input.sparkContext) - - timer.stop("init") - - logDebug("##########") - logDebug("Building tree 0") - logDebug("##########") - - // Initialize tree - timer.start("building tree 0") - val firstTreeModel = new DecisionTree(treeStrategy, seed).run(input) - val firstTreeWeight = 1.0 - baseLearners(0) = firstTreeModel - baseLearnerWeights(0) = firstTreeWeight - - var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. - computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) - predErrorCheckpointer.update(predError) - logDebug("error of gbt = " + predError.values.mean()) - - // Note: A model of type regression is used since we require raw prediction - timer.stop("building tree 0") - - var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. 
- computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) - if (validate) validatePredErrorCheckpointer.update(validatePredError) - var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 - var bestM = 1 - - var m = 1 - var doneLearning = false - while (m < numIterations && !doneLearning) { - // Update data with pseudo-residuals - val data = predError.zip(input).map { case ((pred, _), point) => - LabeledPoint(-loss.gradient(pred, point.label), point.features) - } - - timer.start(s"building tree $m") - logDebug("###################################################") - logDebug("Gradient boosting tree iteration " + m) - logDebug("###################################################") - val model = new DecisionTree(treeStrategy, seed + m).run(data) - timer.stop(s"building tree $m") - // Update partial model - baseLearners(m) = model - // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. - // Technically, the weight should be optimized for the particular loss. - // However, the behavior should be reasonable, though not optimal. - baseLearnerWeights(m) = learningRate - - predError = GradientBoostedTreesModel.updatePredictionError( - input, predError, baseLearnerWeights(m), baseLearners(m), loss) - predErrorCheckpointer.update(predError) - logDebug("error of gbt = " + predError.values.mean()) - - if (validate) { - // Stop training early if - // 1. Reduction in error is less than the validationTol or - // 2. If the error increases, that is if the model is overfit. - // We want the model returned corresponding to the best validation error. - - validatePredError = GradientBoostedTreesModel.updatePredictionError( - validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) - validatePredErrorCheckpointer.update(validatePredError) - val currentValidateError = validatePredError.values.mean() - if (bestValidateError - currentValidateError < validationTol * Math.max( - currentValidateError, 0.01)) { - doneLearning = true - } else if (currentValidateError < bestValidateError) { - bestValidateError = currentValidateError - bestM = m + 1 - } - } - m += 1 - } - - timer.stop("total") - - logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") - - predErrorCheckpointer.deleteAllCheckpoints() - validatePredErrorCheckpointer.deleteAllCheckpoints() - if (persistedInput) input.unpersist() - - if (validate) { - new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, - baseLearners.slice(0, bestM), - baseLearnerWeights.slice(0, bestM)) - } else { - new GradientBoostedTreesModel( - boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) - } - } - } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 914818f41f..3c11631f98 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -53,7 +53,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { sc.parallelize(EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80), 2) } - test("Regression with continuous features: SquaredError") { + test("Regression with continuous features") { val categoricalFeatures = Map.empty[Int, Int] GBTRegressor.supportedLossTypes.foreach { loss => testCombinations.foreach { diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala new file mode 100644 index 0000000000..fecf372c3d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/GradientBoostedTreesSuite.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.{GradientBoostedTreesSuite => OldGBTSuite} +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError} +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[GradientBoostedTrees]]. + */ +class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext with Logging { + + test("runWithValidation stops early and performs better on a validation dataset") { + // Set numIterations large enough so that it stops early. + val numIterations = 20 + val trainRdd = sc.parallelize(OldGBTSuite.trainData, 2) + val validateRdd = sc.parallelize(OldGBTSuite.validateData, 2) + val trainDF = sqlContext.createDataFrame(trainRdd) + val validateDF = sqlContext.createDataFrame(validateRdd) + + val algos = Array(Regression, Regression, Classification) + val losses = Array(SquaredError, AbsoluteError, LogLoss) + algos.zip(losses).foreach { case (algo, loss) => + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val (validateTrees, validateTreeWeights) = GradientBoostedTrees + .runWithValidation(trainRdd, validateRdd, boostingStrategy, 42L) + val numTrees = validateTrees.length + assert(numTrees !== numIterations) + + // Test that it performs better on the validation dataset. 
+ val (trees, treeWeights) = GradientBoostedTrees.run(trainRdd, boostingStrategy, 42L) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (GradientBoostedTrees.computeError(remappedRdd, trees, treeWeights, loss), + GradientBoostedTrees.computeError(remappedRdd, validateTrees, + validateTreeWeights, loss)) + } else { + (GradientBoostedTrees.computeError(validateRdd, trees, treeWeights, loss), + GradientBoostedTrees.computeError(validateRdd, validateTrees, + validateTreeWeights, loss)) + } + } + assert(errorWithValidation <= errorWithoutValidation) + + // Test that results from evaluateEachIteration comply with runWithValidation. + // Note that convergenceTol is set to 0.0 + val evaluationArray = GradientBoostedTrees + .evaluateEachIteration(validateRdd, trees, treeWeights, loss, algo) + assert(evaluationArray.length === numIterations) + assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) + var i = 1 + while (i < numTrees) { + assert(evaluationArray(i) <= evaluationArray(i - 1)) + i += 1 + } + } + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 747c267b4f..c61f89322d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -158,49 +158,6 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext } } - test("runWithValidation stops early and performs better on a validation dataset") { - // Set numIterations large enough so that it stops early. - val numIterations = 20 - val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) - val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) - - val algos = Array(Regression, Regression, Classification) - val losses = Array(SquaredError, AbsoluteError, LogLoss) - algos.zip(losses).foreach { case (algo, loss) => - val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, - categoricalFeaturesInfo = Map.empty) - val boostingStrategy = - new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) - val gbtValidate = new GradientBoostedTrees(boostingStrategy, seed = 0) - .runWithValidation(trainRdd, validateRdd) - val numTrees = gbtValidate.numTrees - assert(numTrees !== numIterations) - - // Test that it performs better on the validation dataset. - val gbt = new GradientBoostedTrees(boostingStrategy, seed = 0).run(trainRdd) - val (errorWithoutValidation, errorWithValidation) = { - if (algo == Classification) { - val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) - (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) - } else { - (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) - } - } - assert(errorWithValidation <= errorWithoutValidation) - - // Test that results from evaluateEachIteration comply with runWithValidation. 
- // Note that convergenceTol is set to 0.0 - val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss) - assert(evaluationArray.length === numIterations) - assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1)) - var i = 1 - while (i < numTrees) { - assert(evaluationArray(i) <= evaluationArray(i - 1)) - i += 1 - } - } - } - test("Checkpointing") { val tempDir = Utils.createTempDir() val path = tempDir.toURI.toString @@ -220,7 +177,7 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext } -private object GradientBoostedTreesSuite { +private[spark] object GradientBoostedTreesSuite { // Combinations for estimators, learning rates and subsamplingRate val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) -- cgit v1.2.3 From 611dbce4bdd6f34ac1fa67d8dfa3d407600a0237 Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Thu, 7 Apr 2016 08:35:00 +0800 Subject: [SPARK-12555][SQL] Result should not be corrupted after input columns are reordered This PR add test case described in SPARK-12555 to validate that correct data is returned when input data is reordered and to avoid future regressions. Author: Luciano Resende Closes #11623 from lresende/SPARK-12555. --- .../org/apache/spark/sql/DatasetAggregatorSuite.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 5430aff6ce..08b3389ad9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -84,6 +84,16 @@ object ComplexBufferAgg extends Aggregator[AggData, (Int, AggData), Int] { (b1._1 + b2._1, b1._2) } +object NameAgg extends Aggregator[AggData, String, String] { + def zero: String = "" + + def reduce(b: String, a: AggData): String = a.b + b + + def merge(b1: String, b2: String): String = b1 + b2 + + def finish(r: String): String = r +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -176,4 +186,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { typed.avg(_._2), typed.count(_._2), typed.sum(_._2), typed.sumLong(_._2)), ("a", 2.0, 2L, 4.0, 4L), ("b", 3.0, 1L, 3.0, 3L)) } + + test("SPARK-12555 - result should not be corrupted after input columns are reordered") { + val ds = sql("SELECT 'Some String' AS b, 1279869254 AS a").as[AggData] + + checkDataset( + ds.groupByKey(_.a).agg(NameAgg.toColumn), + (1279869254, "Some String")) + } + } -- cgit v1.2.3 From 4901086fea969a34ec312ef4a8f83d84e1bf21fb Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 6 Apr 2016 18:30:41 -0700 Subject: [SPARK-14446][TESTS] Fix ReplSuite for Scala 2.10. Just use the same test code as the 2.11 version, which seems to pass. Author: Marcelo Vanzin Closes #12223 from vanzin/SPARK-14446. 
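Both this fix and the DatasetAggregatorSuite addition above exercise the same typed aggregation pattern: an `Aggregator` is turned into a `TypedColumn` with `toColumn` and applied through `groupByKey(...).agg(...)`. A minimal sketch of that pattern (the case class, object and sample values are made up for illustration; it assumes the encoder implicits from `sqlContext.implicits._` are in scope, as they are in the suites):

    import org.apache.spark.sql.expressions.Aggregator

    case class Record(key: Int, name: String)

    // Concatenates the `name` fields of all records sharing a key.
    object ConcatNames extends Aggregator[Record, String, String] {
      def zero: String = ""
      def reduce(buffer: String, r: Record): String = buffer + r.name
      def merge(b1: String, b2: String): String = b1 + b2
      def finish(reduction: String): String = reduction
    }

    // val ds = Seq(Record(1, "x"), Record(1, "y"), Record(2, "z")).toDS()
    // ds.groupByKey(_.key).agg(ConcatNames.toColumn).collect()
    //   => contains (1, "xy") and (2, "z") (concatenation order depends on partitioning)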
--- repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 6b9aa5071e..c8b78bc14a 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -357,7 +357,7 @@ class ReplSuite extends SparkFunSuite { | |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS() - |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() + |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) -- cgit v1.2.3 From d76592276f9f66fed8012d876595de8717f516a9 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 6 Apr 2016 19:25:10 -0700 Subject: [SPARK-12610][SQL] Left Anti Join ### What changes were proposed in this pull request? This PR adds support for `LEFT ANTI JOIN` to Spark SQL. A `LEFT ANTI JOIN` is the exact opposite of a `LEFT SEMI JOIN` and can be used to identify rows in one dataset that are not in another dataset. Note that `nulls` on the left side of the join cannot match a row on the right hand side of the join; the result is that a left anti join will always retain rows with a `null` in one or more of their join keys. We currently add support for the following SQL join syntax: SELECT * FROM tbl1 A LEFT ANTI JOIN tbl2 B ON A.Id = B.Id Or using a dataframe: tbl1.as("a").join(tbl2.as("b"), $"a.id" === $"b.id", "left_anti") This PR serves as the basis for implementing `NOT EXISTS` and `NOT IN (...)` correlated sub-queries. It would also serve as a good basis for implementing a more efficient `EXCEPT` operator. The PR has been (loosely) based on PRs by both davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/10563); credit should be given where credit is due. This PR adds support for `LEFT ANTI JOIN` to `BroadcastHashJoin` (including codegeneration), `ShuffledHashJoin` and `BroadcastNestedLoopJoin`. ### How was this patch tested? Added tests to `JoinSuite` and ported `ExistenceJoinSuite` from https://github.com/apache/spark/pull/10563. cc davies chenghao-intel rxin Author: Herman van Hovell Closes #12214 from hvanhovell/SPARK-12610.
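A small worked example makes the null behaviour described above concrete (the tables, column names and values are made up for illustration; it assumes an existing `sqlContext` and `sparkContext`, as in the test suites):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    val customers = sqlContext.createDataFrame(
      sparkContext.parallelize(Seq(
        Row(1, "Alice"), Row(2, "Bob"), Row(3, "Carol"), Row(null, "Dave"))),
      new StructType().add("id", IntegerType).add("name", StringType))
    val orders = sqlContext.createDataFrame(
      sparkContext.parallelize(Seq(Row(1), Row(3))),
      new StructType().add("customer_id", IntegerType))

    // Customers with no matching order: (2, "Bob") has no match, and the null id
    // can never match anything, so (null, "Dave") is kept as well.
    customers.join(orders, customers("id") === orders("customer_id"), "left_anti").show()

The SQL form follows the syntax shown above: SELECT c.* FROM customers c LEFT ANTI JOIN orders o ON c.id = o.customer_id, once the two DataFrames are registered as temporary tables.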
--- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 + .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 8 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 1 + .../spark/sql/catalyst/plans/joinTypes.scala | 17 ++- .../catalyst/plans/logical/basicOperators.scala | 4 +- .../sql/catalyst/parser/PlanParserSuite.scala | 5 +- .../apache/spark/sql/execution/SparkPlanner.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 11 +- .../sql/execution/joins/BroadcastHashJoin.scala | 99 +++++++++---- .../execution/joins/BroadcastNestedLoopJoin.scala | 57 +++++--- .../spark/sql/execution/joins/HashJoin.scala | 18 ++- .../sql/execution/joins/ShuffledHashJoin.scala | 1 + .../scala/org/apache/spark/sql/JoinSuite.scala | 36 ++--- .../sql/execution/joins/ExistenceJoinSuite.scala | 159 +++++++++++++++++++++ .../spark/sql/execution/joins/SemiJoinSuite.scala | 129 ----------------- .../apache/spark/sql/hive/HiveSessionState.scala | 2 +- 17 files changed, 338 insertions(+), 215 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 8a45b4f2e1..85cb585919 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -380,6 +380,7 @@ joinType | LEFT SEMI | RIGHT OUTER? | FULL OUTER? + | LEFT? ANTI ; joinCriteria @@ -878,6 +879,7 @@ INDEX: 'INDEX'; INDEXES: 'INDEXES'; LOCKS: 'LOCKS'; OPTION: 'OPTION'; +ANTI: 'ANTI'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 473c91e69e..bc8cf4e78a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1424,7 +1424,7 @@ class Analyzer( val projectList = joinType match { case LeftOuter => leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) - case LeftSemi => + case LeftExistence(_) => leftKeys ++ lUniqueOutput case RightOuter => rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c085a377ff..f581810c26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -361,8 +361,8 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case j @ Join(left, right, LeftSemi, condition) => + // Eliminate unneeded attributes from right side of a Left Existence Join. 
+ case j @ Join(left, right, LeftExistence(_), condition) => j.copy(right = prunedChild(right, j.references)) // all the columns will be used to compare, so we can't prune them @@ -1126,7 +1126,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) - case _ @ (LeftOuter | LeftSemi) => + case LeftOuter | LeftExistence(_) => // push down the left side only `where` condition val newLeft = leftFilterConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -1147,7 +1147,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case _ @ (Inner | LeftSemi) => + case Inner | LeftExistence(_) => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5a3aebff09..aa59f3fb2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -572,6 +572,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case null => Inner case jt if jt.FULL != null => FullOuter case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti case jt if jt.LEFT != null => LeftOuter case jt if jt.RIGHT != null => RightOuter case _ => Inner diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 9ca4f13dd7..13f57c54a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -26,13 +26,15 @@ object JoinType { case "leftouter" | "left" => LeftOuter case "rightouter" | "right" => RightOuter case "leftsemi" => LeftSemi + case "leftanti" => LeftAnti case _ => val supported = Seq( "inner", "outer", "full", "fullouter", "leftouter", "left", "rightouter", "right", - "leftsemi") + "leftsemi", + "leftanti") throw new IllegalArgumentException(s"Unsupported join type '$typ'. 
" + "Supported join types include: " + supported.mkString("'", "', '", "'") + ".") @@ -63,6 +65,10 @@ case object LeftSemi extends JoinType { override def sql: String = "LEFT SEMI" } +case object LeftAnti extends JoinType { + override def sql: String = "LEFT ANTI" +} + case class NaturalJoin(tpe: JoinType) extends JoinType { require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), "Unsupported natural join type " + tpe) @@ -70,7 +76,14 @@ case class NaturalJoin(tpe: JoinType) extends JoinType { } case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType { - require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe), + require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter, LeftAnti).contains(tpe), "Unsupported using join type " + tpe) override def sql: String = "USING " + tpe.sql } + +object LeftExistence { + def unapply(joinType: JoinType): Option[JoinType] = joinType match { + case LeftSemi | LeftAnti => Some(joinType) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a18efc90ab..d3353beb09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -252,7 +252,7 @@ case class Join( override def output: Seq[Attribute] = { joinType match { - case LeftSemi => + case LeftExistence(_) => left.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -276,7 +276,7 @@ case class Join( .union(splitConjunctivePredicates(condition.get).toSet) case Inner => left.constraints.union(right.constraints) - case LeftSemi => + case LeftExistence(_) => left.constraints case LeftOuter => left.constraints diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 262537d9c7..411e2372f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -334,7 +334,7 @@ class PlanParserSuite extends PlanTest { table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) } val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) - + val testExistence = Seq(testUnconditionalJoin, testConditionalJoin, testUsingJoin) def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { tests.foreach(_(sql, jt)) } @@ -348,6 +348,9 @@ class PlanParserSuite extends PlanTest { test("right outer join", RightOuter, testAll) test("full join", FullOuter, testAll) test("full outer join", FullOuter, testAll) + test("left semi join", LeftSemi, testExistence) + test("left anti join", LeftAnti, testExistence) + test("anti join", LeftAnti, testExistence) // Test multiple consecutive joins assertEqual( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index ac8072f3ca..8d05ae470d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -38,7 +38,7 @@ 
class SparkPlanner( DDLStrategy :: SpecialLimits :: Aggregation :: - LeftSemiJoin :: + ExistenceJoin :: EquiJoinSelection :: InMemoryScans :: BasicOperators :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d77aba7260..eee2b946e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -62,16 +62,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object LeftSemiJoin extends Strategy with PredicateHelper { + object ExistenceJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys( - LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) => Seq(joins.BroadcastHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) // Find left semi joins where at least some predicates can be evaluated by matching join keys - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => + case ExtractEquiJoinKeys( + LeftExistence(jt), leftKeys, rightKeys, condition, left, right) => Seq(joins.ShuffledHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 67ac9e94ff..e3d554c2de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -87,6 +86,7 @@ case class BroadcastHashJoin( case Inner => codegenInner(ctx, input) case LeftOuter | RightOuter => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) + case LeftAnti => codegenAnti(ctx, input) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") @@ -160,15 +160,14 @@ case class BroadcastHashJoin( } /** - * Generates the code for Inner join. + * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi + * and Left Anti joins. 
*/ - private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + private def getJoinCondition( + ctx: CodegenContext, + input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { val matched = ctx.freshName("matched") val buildVars = genBuildSideVars(ctx, matched) - val numOutput = metricTerm(ctx, "numOutputRows") - val checkCondition = if (condition.isDefined) { val expr = condition.get // evaluate the variables from build side that used by condition @@ -184,6 +183,17 @@ case class BroadcastHashJoin( } else { "" } + (matched, checkCondition, buildVars) + } + + /** + * Generates the code for Inner join. + */ + private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") val resultVars = buildSide match { case BuildLeft => buildVars ++ input @@ -221,7 +231,6 @@ case class BroadcastHashJoin( } } - /** * Generates the code for left or right outer join. */ @@ -276,7 +285,6 @@ case class BroadcastHashJoin( ctx.copyResult = true val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName - val i = ctx.freshName("i") val found = ctx.freshName("found") s""" |// generate join key for stream side @@ -304,26 +312,8 @@ case class BroadcastHashJoin( private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val matched = ctx.freshName("matched") - val buildVars = genBuildSideVars(ctx, matched) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) val numOutput = metricTerm(ctx, "numOutputRows") - - val checkCondition = if (condition.isDefined) { - val expr = condition.get - // evaluate the variables from build side that used by condition - val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) - // filter the output via condition - ctx.currentVars = input ++ buildVars - val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx) - s""" - |$eval - |${ev.code} - |if (${ev.isNull} || !${ev.value}) continue; - """.stripMargin - } else { - "" - } - if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side @@ -357,4 +347,57 @@ case class BroadcastHashJoin( """.stripMargin } } + + /** + * Generates the code for anti join. + */ + private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + | if ($matched != null) { + | // Evaluate the condition. 
+ | $checkCondition + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value}); + | if ($matches != null) { + | // Evaluate the condition. + | boolean $found = false; + | while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $found = true; + | } + | if ($found) continue; + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 4143e944e5..4ba710c10a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -73,7 +73,7 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case LeftSemi => + case LeftExistence(_) => left.output case x => throw new IllegalArgumentException( @@ -175,8 +175,11 @@ case class BroadcastNestedLoopJoin( * The implementation for these joins: * * LeftSemi with BuildRight + * Anti with BuildRight */ - private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + private def leftExistenceJoin( + relation: Broadcast[Array[InternalRow]], + exists: Boolean): RDD[InternalRow] = { assert(buildSide == BuildRight) streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value @@ -184,10 +187,12 @@ case class BroadcastNestedLoopJoin( if (condition.isDefined) { streamedIter.filter(l => - buildRows.exists(r => boundCondition(joinedRow(l, r))) + buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists ) + } else if (buildRows.nonEmpty == exists) { + streamedIter } else { - streamedIter.filter(r => !buildRows.isEmpty) + Iterator.empty } } } @@ -199,6 +204,7 @@ case class BroadcastNestedLoopJoin( * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft + * Anti with BuildLeft */ private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ @@ -236,7 +242,27 @@ case class BroadcastNestedLoopJoin( } i += 1 } - return sparkContext.makeRDD(buf.toSeq) + return sparkContext.makeRDD(buf) + } + + val notMatchedBroadcastRows: Seq[InternalRow] = { + val nulls = new GenericMutableRow(streamed.output.size) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val buildRows = relation.value + val joinedRow = new JoinedRow + joinedRow.withLeft(nulls) + while (i < buildRows.length) { + if (!matchedBroadcastRows.get(i)) { + buf += joinedRow.withRight(buildRows(i)).copy() + } + i += 1 + } + buf + } + + if (joinType == LeftAnti) { + return sparkContext.makeRDD(notMatchedBroadcastRows) } val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => @@ -264,22 +290,6 @@ case class BroadcastNestedLoopJoin( } } - val notMatchedBroadcastRows: Seq[InternalRow] = { - val nulls = new GenericMutableRow(streamed.output.size) - val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val buildRows = relation.value - val joinedRow = new JoinedRow - joinedRow.withLeft(nulls) - while (i < buildRows.length) { - if (!matchedBroadcastRows.get(i)) { - buf += joinedRow.withRight(buildRows(i)).copy() - } - i += 1 - } - buf.toSeq - } - sparkContext.union( matchedStreamRows, sparkContext.makeRDD(notMatchedBroadcastRows) @@ -295,13 +305,16 @@ case class BroadcastNestedLoopJoin( case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) case (LeftSemi, BuildRight) => - leftSemiJoin(broadcastedRelation) + leftExistenceJoin(broadcastedRelation, exists = true) + case (LeftAnti, BuildRight) => + leftExistenceJoin(broadcastedRelation, exists = false) case _ => /** * LeftOuter with BuildLeft * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft + * Anti with BuildLeft */ defaultJoin(broadcastedRelation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index b7c0f3e7d1..8f45d57126 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -47,7 +47,7 @@ trait HashJoin { left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output - case LeftSemi => + case LeftExistence(_) => left.output case x => throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType") @@ -197,6 +197,20 @@ trait HashJoin { } } + private def antiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter { current => + val key = joinKeys(current) + lazy val buildIter = hashedRelation.get(key) + key.anyNull || buildIter == null || (condition.isDefined && !buildIter.exists { + row => boundCondition(joinedRow(current, row)) + }) + } + } + protected def join( streamedIter: Iterator[InternalRow], hashed: HashedRelation, @@ -209,6 +223,8 @@ trait HashJoin { outerJoin(streamedIter, hashed) case LeftSemi => semiJoin(streamedIter, hashed) + case LeftAnti => + antiJoin(streamedIter, hashed) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index c63faacf33..bf86096379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,6 +45,7 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = joinType match { case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftAnti => left.outputPartitioning case LeftSemi => left.outputPartitioning case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a5a4ff13de..a87a41c126 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -41,7 +41,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { assert(planned.size === 1) } - def assertJoin(sqlString: String, c: Class[_]): Any = { + def assertJoin(pair: (String, Class[_])): Any = { + val (sqlString, c) = pair val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { @@ -53,8 +54,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") + if (operators.head.getClass != c) { + fail(s"$sqlString expected operator: $c, but got ${operators.head}\n physical: \n$physical") } } @@ -93,8 +94,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoin]) + ).foreach(assertJoin) } } @@ -114,7 +117,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) sql("UNCACHE TABLE testData") } @@ -129,7 +132,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastHashJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) sql("UNCACHE TABLE testData") } @@ -419,25 +422,22 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, 10)) } - test("broadcasted left semi join operator selection") { + test("broadcasted existence join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastHashJoin]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) - } + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON 
key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[BroadcastHashJoin]) + ).foreach(assertJoin) } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) - } + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]) + ).foreach(assertJoin) } sql("UNCACHE TABLE testData") @@ -489,7 +489,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) checkAnswer( sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala new file mode 100644 index 0000000000..8cdfa8afd0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + private lazy val conditionNEQ = { + And((left.col("a") < right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testExistenceJoin( + testName: String, + joinType: JoinType, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, + expectedAnswer: Seq[Row]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + ShuffledHashJoin( + leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastHashJoin( + leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition))), + expectedAnswer, + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + 
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition))), + expectedAnswer, + sortAnswers = true) + } + } + } + + testExistenceJoin( + "basic test for left semi join", + LeftSemi, + left, + right, + condition, + Seq(Row(2, 1.0), Row(2, 1.0))) + + testExistenceJoin( + "basic test for left semi non equal join", + LeftSemi, + left, + right, + conditionNEQ, + Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0))) + + testExistenceJoin( + "basic test for anti join", + LeftAnti, + left, + right, + condition, + Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) + + testExistenceJoin( + "basic test for anti non equal join", + LeftAnti, + left, + right, + conditionNEQ, + Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala deleted file mode 100644 index 985a96f684..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} -import org.apache.spark.sql.execution.exchange.EnsureRequirements -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} - -class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { - - private lazy val left = sqlContext.createDataFrame( - sparkContext.parallelize(Seq( - Row(1, 2.0), - Row(1, 2.0), - Row(2, 1.0), - Row(2, 1.0), - Row(3, 3.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("a", IntegerType).add("b", DoubleType)) - - private lazy val right = sqlContext.createDataFrame( - sparkContext.parallelize(Seq( - Row(2, 3.0), - Row(2, 3.0), - Row(3, 2.0), - Row(4, 1.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("c", IntegerType).add("d", DoubleType)) - - private lazy val condition = { - And((left.col("a") === right.col("c")).expr, - LessThan(left.col("b").expr, right.col("d").expr)) - } - - // Note: the input dataframes and expression must be evaluated lazily because - // the SQLContext should be used only within a test to keep SQL tests stable - private def testLeftSemiJoin( - testName: String, - leftRows: => DataFrame, - rightRows: => DataFrame, - condition: => Expression, - expectedAnswer: Seq[Product]): Unit = { - - def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join) - } - - test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext.sessionState.conf).apply( - ShuffledHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using BroadcastHashJoin") { - extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using BroadcastNestedLoopJoin build left") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)), - expectedAnswer.map(Row.fromTuple), 
- sortAnswers = true) - } - } - } - - testLeftSemiJoin( - "basic test", - left, - right, - condition, - Seq( - (2, 1.0), - (2, 1.0) - ) - ) -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index cff24e28fd..b992fda18c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -92,7 +92,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) DataSinks, Scripts, Aggregation, - LeftSemiJoin, + ExistenceJoin, EquiJoinSelection, BasicOperators, BroadcastNestedLoop, -- cgit v1.2.3 From 21d5ca128bf3afd5c2d4c7fcc56240e28443474f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 6 Apr 2016 19:33:48 -0700 Subject: [SPARK-14134][CORE] Change the package name used for shading classes. The current package name uses a dash, which is a little weird but seemed to work. That is, until a new test tried to mock a class that references one of those shaded types, and then things started failing. Most changes are just noise to fix the logging configs. For reference, SPARK-8815 also raised this issue, although at the time it did not cause any issues in Spark, so it was not addressed. Author: Marcelo Vanzin Closes #11941 from vanzin/SPARK-14134. --- common/network-yarn/pom.xml | 4 ++-- conf/log4j.properties.template | 4 ++-- core/src/main/resources/org/apache/spark/log4j-defaults.properties | 4 ++-- core/src/test/resources/log4j.properties | 3 +-- external/flume-sink/src/test/resources/log4j.properties | 2 +- external/flume/src/test/resources/log4j.properties | 2 +- external/java8-tests/src/test/resources/log4j.properties | 2 +- external/kafka/src/test/resources/log4j.properties | 2 +- external/kinesis-asl/src/main/resources/log4j.properties | 4 ++-- external/kinesis-asl/src/test/resources/log4j.properties | 2 +- graphx/src/test/resources/log4j.properties | 3 +-- launcher/src/test/resources/log4j.properties | 3 +-- mllib/src/test/resources/log4j.properties | 2 +- pom.xml | 7 +++++-- repl/src/test/resources/log4j.properties | 2 +- sql/catalyst/src/test/resources/log4j.properties | 3 +-- streaming/src/test/resources/log4j.properties | 2 +- yarn/src/test/resources/log4j.properties | 2 +- .../scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala | 2 +- 19 files changed, 27 insertions(+), 28 deletions(-) diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 3cb44324f2..bc83ef24c3 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -36,7 +36,7 @@ provided ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar - org/spark-project/ + org/spark_project/ @@ -91,7 +91,7 @@ com.fasterxml.jackson - org.spark-project.com.fasterxml.jackson + ${spark.shade.packageName}.com.fasterxml.jackson com.fasterxml.jackson.** diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 9809b0c828..ec1aa187df 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: log4j.logger.org.apache.spark.repl.Main=WARN # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark_project.jetty=WARN 
+log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO log4j.logger.org.apache.parquet=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 0750488e4a..89a7963a86 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: log4j.logger.org.apache.spark.repl.Main=WARN # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark_project.jetty=WARN +log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index a54d27de91..fb9d9851cb 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -33,5 +33,4 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%t: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 42df8792f1..1e3f163f95 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 75e3b53a09..fd51f8faf5 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/java8-tests/src/test/resources/log4j.properties b/external/java8-tests/src/test/resources/log4j.properties index edbecdae92..3706a6e361 100644 --- a/external/java8-tests/src/test/resources/log4j.properties +++ b/external/java8-tests/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff 
--git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 75e3b53a09..fd51f8faf5 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/kinesis-asl/src/main/resources/log4j.properties b/external/kinesis-asl/src/main/resources/log4j.properties index 6cdc9286c5..8118d12c5d 100644 --- a/external/kinesis-asl/src/main/resources/log4j.properties +++ b/external/kinesis-asl/src/main/resources/log4j.properties @@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark_project.jetty=WARN +log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/external/kinesis-asl/src/test/resources/log4j.properties b/external/kinesis-asl/src/test/resources/log4j.properties index edbecdae92..3706a6e361 100644 --- a/external/kinesis-asl/src/test/resources/log4j.properties +++ b/external/kinesis-asl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index eb3b1999eb..3706a6e361 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index c64b1565e1..744c456cb2 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -30,5 +30,4 @@ log4j.appender.childproc.layout=org.apache.log4j.PatternLayout log4j.appender.childproc.layout.ConversionPattern=%t: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 75e3b53a09..fd51f8faf5 100644 --- a/mllib/src/test/resources/log4j.properties +++ b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ 
log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/pom.xml b/pom.xml index 984b2859ef..66a34e4bdf 100644 --- a/pom.xml +++ b/pom.xml @@ -182,6 +182,9 @@ ${java.home} + + org.spark_project + ${project.build.directory}/scala-${scala.binary.version}/jars @@ -2204,14 +2207,14 @@ org.eclipse.jetty - org.spark-project.jetty + ${spark.shade.packageName}.jetty org.eclipse.jetty.** com.google.common - org.spark-project.guava + ${spark.shade.packageName}.guava diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index e2ee9c963a..7665bd5e7c 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties index eb3b1999eb..3706a6e361 100644 --- a/sql/catalyst/src/test/resources/log4j.properties +++ b/sql/catalyst/src/test/resources/log4j.properties @@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 75e3b53a09..fd51f8faf5 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 6b9a799954..d13454d5ae 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -28,4 +28,4 @@ log4j.logger.com.sun.jersey=WARN log4j.logger.org.apache.hadoop=WARN log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.mortbay=WARN -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 2f3a31cb04..9c3b18e4ec 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -53,7 +53,7 @@ abstract class BaseYarnClusterSuite |log4j.logger.org.apache.hadoop=WARN |log4j.logger.org.eclipse.jetty=WARN |log4j.logger.org.mortbay=WARN - |log4j.logger.org.spark-project.jetty=WARN + 
|log4j.logger.org.spark_project.jetty=WARN """.stripMargin private var yarnCluster: MiniYARNCluster = _ -- cgit v1.2.3 From e11aa9ec5c3cdcd8ca08d2486a7208840ad77bf8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 7 Apr 2016 00:46:57 -0700 Subject: [SPARK-14452][SQL] Explicit APIs in Scala for specifying encoders ## What changes were proposed in this pull request? The Scala Dataset public API currently only allows users to specify encoders through SQLContext.implicits. This is OK but sometimes people want to explicitly get encoders without a SQLContext (e.g. Aggregator implementations). This patch adds public APIs to Encoders class for getting Scala encoders. ## How was this patch tested? None - I will update test cases once https://github.com/apache/spark/pull/12231 is merged. Author: Reynold Xin Closes #12232 from rxin/SPARK-14452. --- .../main/scala/org/apache/spark/sql/Encoder.scala | 231 +-------------- .../main/scala/org/apache/spark/sql/Encoders.scala | 314 +++++++++++++++++++++ .../scala/org/apache/spark/sql/SQLImplicits.scala | 18 +- 3 files changed, 327 insertions(+), 236 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index e0bfe3c32f..ffa694fcdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,22 +17,20 @@ package org.apache.spark.sql -import java.lang.reflect.Modifier - import scala.annotation.implicitNotFound -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer} import org.apache.spark.sql.types._ + /** * :: Experimental :: * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. * * == Scala == - * Encoders are generally created automatically through implicits from a `SQLContext`. + * Encoders are generally created automatically through implicits from a `SQLContext`, or can be + * explicitly created by calling static methods on [[Encoders]]. * * {{{ * import sqlContext.implicits._ @@ -81,224 +79,3 @@ trait Encoder[T] extends Serializable { /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ def clsTag: ClassTag[T] } - -/** - * :: Experimental :: - * Methods for creating an [[Encoder]]. - * - * @since 1.6.0 - */ -@Experimental -object Encoders { - - /** - * An encoder for nullable boolean type. - * @since 1.6.0 - */ - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() - - /** - * An encoder for nullable byte type. - * @since 1.6.0 - */ - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() - - /** - * An encoder for nullable short type. - * @since 1.6.0 - */ - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() - - /** - * An encoder for nullable int type. - * @since 1.6.0 - */ - def INT: Encoder[java.lang.Integer] = ExpressionEncoder() - - /** - * An encoder for nullable long type. - * @since 1.6.0 - */ - def LONG: Encoder[java.lang.Long] = ExpressionEncoder() - - /** - * An encoder for nullable float type. - * @since 1.6.0 - */ - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() - - /** - * An encoder for nullable double type. 
- * @since 1.6.0 - */ - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() - - /** - * An encoder for nullable string type. - * @since 1.6.0 - */ - def STRING: Encoder[java.lang.String] = ExpressionEncoder() - - /** - * An encoder for nullable decimal type. - * @since 1.6.0 - */ - def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() - - /** - * An encoder for nullable date type. - * @since 1.6.0 - */ - def DATE: Encoder[java.sql.Date] = ExpressionEncoder() - - /** - * An encoder for nullable timestamp type. - * @since 1.6.0 - */ - def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() - - /** - * An encoder for arrays of bytes. - * @since 1.6.1 - */ - def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() - - /** - * Creates an encoder for Java Bean of type T. - * - * T must be publicly accessible. - * - * supported types for java bean field: - * - primitive types: boolean, int, double, etc. - * - boxed types: Boolean, Integer, Double, etc. - * - String - * - java.math.BigDecimal - * - time related: java.sql.Date, java.sql.Timestamp - * - collection types: only array and java.util.List currently, map support is in progress - * - nested java bean. - * - * @since 1.6.0 - */ - def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) - - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) - - /** - * Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) - - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java - * serialization. This encoder maps T into a single byte array (binary) field. - * - * Note that this is extremely inefficient and should only be used as the last resort. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) - - /** - * Creates an encoder that serializes objects of type T using generic Java serialization. - * This encoder maps T into a single byte array (binary) field. - * - * Note that this is extremely inefficient and should only be used as the last resort. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) - - /** Throws an exception if T is not a public class. */ - private def validatePublicClass[T: ClassTag](): Unit = { - if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { - throw new UnsupportedOperationException( - s"${classTag[T].runtimeClass.getName} is not a public class. " + - "Only public classes are supported.") - } - } - - /** A way to construct encoders using generic serializers. 
*/ - private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { - if (classTag[T].runtimeClass.isPrimitive) { - throw new UnsupportedOperationException("Primitive types are not supported.") - } - - validatePublicClass[T]() - - ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - serializer = Seq( - EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - deserializer = - DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), - clsTag = classTag[T] - ) - } - - /** - * An encoder for 2-ary tuples. - * @since 1.6.0 - */ - def tuple[T1, T2]( - e1: Encoder[T1], - e2: Encoder[T2]): Encoder[(T1, T2)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) - } - - /** - * An encoder for 3-ary tuples. - * @since 1.6.0 - */ - def tuple[T1, T2, T3]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) - } - - /** - * An encoder for 4-ary tuples. - * @since 1.6.0 - */ - def tuple[T1, T2, T3, T4]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3], - e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) - } - - /** - * An encoder for 5-ary tuples. - * @since 1.6.0 - */ - def tuple[T1, T2, T3, T4, T5]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3], - e4: Encoder[T4], - e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { - ExpressionEncoder.tuple( - encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala new file mode 100644 index 0000000000..3f4df704db --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.reflect.Modifier + +import scala.reflect.{classTag, ClassTag} +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer} +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Methods for creating an [[Encoder]]. + * + * @since 1.6.0 + */ +@Experimental +object Encoders { + + /** + * An encoder for nullable boolean type. + * The Scala primitive encoder is available as [[scalaBoolean]]. 
+ * @since 1.6.0 + */ + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + + /** + * An encoder for nullable byte type. + * The Scala primitive encoder is available as [[scalaByte]]. + * @since 1.6.0 + */ + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + + /** + * An encoder for nullable short type. + * The Scala primitive encoder is available as [[scalaShort]]. + * @since 1.6.0 + */ + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + + /** + * An encoder for nullable int type. + * The Scala primitive encoder is available as [[scalaInt]]. + * @since 1.6.0 + */ + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + + /** + * An encoder for nullable long type. + * The Scala primitive encoder is available as [[scalaLong]]. + * @since 1.6.0 + */ + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + + /** + * An encoder for nullable float type. + * The Scala primitive encoder is available as [[scalaFloat]]. + * @since 1.6.0 + */ + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + + /** + * An encoder for nullable double type. + * The Scala primitive encoder is available as [[scalaDouble]]. + * @since 1.6.0 + */ + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + + /** + * An encoder for nullable string type. + * + * @since 1.6.0 + */ + def STRING: Encoder[java.lang.String] = ExpressionEncoder() + + /** + * An encoder for nullable decimal type. + * + * @since 1.6.0 + */ + def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() + + /** + * An encoder for nullable date type. + * + * @since 1.6.0 + */ + def DATE: Encoder[java.sql.Date] = ExpressionEncoder() + + /** + * An encoder for nullable timestamp type. + * + * @since 1.6.0 + */ + def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + + /** + * An encoder for arrays of bytes. + * + * @since 1.6.1 + */ + def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() + + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * supported types for java bean field: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal + * - time related: java.sql.Date, java.sql.Timestamp + * - collection types: only array and java.util.List currently, map support is in progress + * - nested java bean. + * + * @since 1.6.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + + /** + * Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. 
+ * + * @since 1.6.0 + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + serializer = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + deserializer = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } + + /** + * An encoder for 2-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2]( + e1: Encoder[T1], + e2: Encoder[T2]): Encoder[(T1, T2)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) + } + + /** + * An encoder for 3-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) + } + + /** + * An encoder for 4-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3, T4]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) + } + + /** + * An encoder for 5-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3, T4, T5]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4], + e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + ExpressionEncoder.tuple( + encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) + } + + /** + * An encoder for Scala's product type (tuples, case classes, etc). + * @since 2.0.0 + */ + def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive int type. + * @since 2.0.0 + */ + def scalaInt: Encoder[Int] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive long type. + * @since 2.0.0 + */ + def scalaLong: Encoder[Long] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive double type. + * @since 2.0.0 + */ + def scalaDouble: Encoder[Double] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive float type. + * @since 2.0.0 + */ + def scalaFloat: Encoder[Float] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive byte type. + * @since 2.0.0 + */ + def scalaByte: Encoder[Byte] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive short type. 
+ * @since 2.0.0 + */ + def scalaShort: Encoder[Short] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive boolean type. + * @since 2.0.0 + */ + def scalaBoolean: Encoder[Boolean] = ExpressionEncoder() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index c35a969bf0..ad69e23540 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -44,33 +44,33 @@ abstract class SQLImplicits { } /** @since 1.6.0 */ - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T] // Primitives /** @since 1.6.0 */ - implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt /** @since 1.6.0 */ - implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + implicit def newLongEncoder: Encoder[Long] = Encoders.scalaLong /** @since 1.6.0 */ - implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + implicit def newDoubleEncoder: Encoder[Double] = Encoders.scalaDouble /** @since 1.6.0 */ - implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + implicit def newFloatEncoder: Encoder[Float] = Encoders.scalaFloat /** @since 1.6.0 */ - implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + implicit def newByteEncoder: Encoder[Byte] = Encoders.scalaByte /** @since 1.6.0 */ - implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + implicit def newShortEncoder: Encoder[Short] = Encoders.scalaShort /** @since 1.6.0 */ - implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + implicit def newBooleanEncoder: Encoder[Boolean] = Encoders.scalaBoolean /** @since 1.6.0 */ - implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() + implicit def newStringEncoder: Encoder[String] = Encoders.STRING // Seqs -- cgit v1.2.3 From 9ca0760d6769199f164a661655912f028234eb1c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 7 Apr 2016 00:51:45 -0700 Subject: [SPARK-10063][SQL] Remove DirectParquetOutputCommitter ## What changes were proposed in this pull request? This patch removes DirectParquetOutputCommitter. This was initially created by Databricks as a faster way to write Parquet data to S3. However, given how the underlying S3 Hadoop implementation works, this committer only works when there are no failures. If there are multiple attempts of the same task (e.g. speculation or task failures or node failures), the output data can be corrupted. I don't think this performance optimization outweighs the correctness issue. ## How was this patch tested? Removed the related tests also. Author: Reynold Xin Closes #12229 from rxin/SPARK-10063. 
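The opt-in that this patch removes lived in the Hadoop configuration rather than in SparkConf. A minimal sketch of the now-obsolete setting, assuming a plain `sqlContext` (after this change the class below no longer exists, so the property should simply be left unset and Parquet writes use the default committer):

```scala
// Pre-patch opt-in only: DirectParquetOutputCommitter is deleted by this change,
// so jobs that still set this property will fail to load the committer class.
sqlContext.sparkContext.hadoopConfiguration.set(
  "spark.sql.parquet.output.committer.class",
  "org.apache.spark.sql.execution.datasources.parquet.DirectParquetOutputCommitter")
```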
--- docs/sql-programming-guide.md | 33 -------- .../execution/datasources/WriterContainer.scala | 18 ++--- .../parquet/DirectParquetOutputCommitter.scala | 88 ---------------------- .../datasources/parquet/ParquetRelation.scala | 7 -- .../datasources/parquet/ParquetIOSuite.scala | 49 ------------ .../spark/sql/sources/hadoopFsRelationSuites.scala | 34 --------- 6 files changed, 5 insertions(+), 224 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 274a8edb0c..63310be22c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1466,37 +1466,6 @@ Configuration of Parquet can be done using the `setConf` method on `SQLContext` support. - - - - - @@ -2165,8 +2134,6 @@ options. - In the `sql` dialect, floating point numbers are now parsed as decimal. HiveQL parsing remains unchanged. - The canonical name of SQL/DataFrame functions are now lower case (e.g. sum vs SUM). - - It has been determined that using the DirectOutputCommitter when speculation is enabled is unsafe - and thus this output committer will not be used when speculation is on, independent of configuration. - JSON data source will not automatically load new files that are created by other applications (i.e. files that are not inserted to the dataset through Spark SQL). For a JSON persistent table (i.e. the metadata of the table is stored in Hive Metastore), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 233ac263aa..f6b7f0854b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -129,16 +129,17 @@ private[sql] abstract class BaseWriterContainer( outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) } catch { case e: org.apache.hadoop.fs.FileAlreadyExistsException => - if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { - // Spark-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry + if (outputCommitter.getClass.getName.contains("Direct")) { + // SPARK-11382: DirectParquetOutputCommitter is not idempotent, meaning on retry // attempts, the task will fail because the output file is created from a prior attempt. // This often means the most visible error to the user is misleading. Augment the error // to tell the user to look for the actual error. throw new SparkException("The output file already exists but this could be due to a " + "failure from an earlier attempt. Look through the earlier logs or stage page for " + - "the first error.\n File exists error: " + e) + "the first error.\n File exists error: " + e.getLocalizedMessage, e) + } else { + throw e } - throw e } } @@ -156,15 +157,6 @@ private[sql] abstract class BaseWriterContainer( s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + "for appending.") defaultOutputCommitter - } else if (speculationEnabled) { - // When speculation is enabled, it's not safe to use customized output committer classes, - // especially direct output committers (e.g. `DirectParquetOutputCommitter`). - // - // See SPARK-9899 for more details. 
- logInfo( - s"Using default output committer ${defaultOutputCommitter.getClass.getCanonicalName} " + - "because spark.speculation is configured to be true.") - defaultOutputCommitter } else { val configuration = context.getConfiguration val committerClass = configuration.getClass( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala deleted file mode 100644 index ecadb9e7c6..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.datasources.parquet - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter -import org.apache.parquet.Log -import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} -import org.apache.parquet.hadoop.util.ContextUtil - -/** - * An output committer for writing Parquet files. In stead of writing to the `_temporary` folder - * like what [[ParquetOutputCommitter]] does, this output committer writes data directly to the - * destination folder. This can be useful for data stored in S3, where directory operations are - * relatively expensive. - * - * To enable this output committer, users may set the "spark.sql.parquet.output.committer.class" - * property via Hadoop [[Configuration]]. Not that this property overrides - * "spark.sql.sources.outputCommitterClass". - * - * *NOTE* - * - * NEVER use [[DirectParquetOutputCommitter]] when appending data, because currently there's - * no safe way undo a failed appending job (that's why both `abortTask()` and `abortJob()` are - * left empty). 
- */ -private[datasources] class DirectParquetOutputCommitter( - outputPath: Path, context: TaskAttemptContext) - extends ParquetOutputCommitter(outputPath, context) { - val LOG = Log.getLog(classOf[ParquetOutputCommitter]) - - override def getWorkPath: Path = outputPath - override def abortTask(taskContext: TaskAttemptContext): Unit = {} - override def commitTask(taskContext: TaskAttemptContext): Unit = {} - override def needsTaskCommit(taskContext: TaskAttemptContext): Boolean = true - override def setupJob(jobContext: JobContext): Unit = {} - override def setupTask(taskContext: TaskAttemptContext): Unit = {} - - override def commitJob(jobContext: JobContext) { - val configuration = ContextUtil.getConfiguration(jobContext) - val fileSystem = outputPath.getFileSystem(configuration) - - if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { - try { - val outputStatus = fileSystem.getFileStatus(outputPath) - val footers = ParquetFileReader.readAllFootersInParallel(configuration, outputStatus) - try { - ParquetFileWriter.writeMetadataFile(configuration, outputPath, footers) - } catch { case e: Exception => - LOG.warn("could not write summary file for " + outputPath, e) - val metadataPath = new Path(outputPath, ParquetFileWriter.PARQUET_METADATA_FILE) - if (fileSystem.exists(metadataPath)) { - fileSystem.delete(metadataPath, true) - } - } - } catch { - case e: Exception => LOG.warn("could not write summary file for " + outputPath, e) - } - } - - if (configuration.getBoolean("mapreduce.fileoutputcommitter.marksuccessfuljobs", true)) { - try { - val successPath = new Path(outputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) - fileSystem.create(successPath).close() - } catch { - case e: Exception => LOG.warn("could not write success file for " + outputPath, e) - } - } - } -} - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index a2fd8da782..5ad95e4b9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -76,13 +76,6 @@ private[sql] class DefaultSource val conf = ContextUtil.getConfiguration(job) - // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible - val committerClassName = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) - if (committerClassName == "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") { - conf.set(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, - classOf[DirectParquetOutputCommitter].getCanonicalName) - } - val committerClass = conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index a3017258d6..581095d3dc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -445,55 +445,6 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - testQuietly("SPARK-6352 DirectParquetOutputCommitter") { - val clonedConf = new Configuration(hadoopConfiguration) - - // Write to a parquet file and let it fail. 
- // _temporary should be missing if direct output committer works. - try { - hadoopConfiguration.set("spark.sql.parquet.output.committer.class", - classOf[DirectParquetOutputCommitter].getCanonicalName) - sqlContext.udf.register("div0", (x: Int) => x / 0) - withTempPath { dir => - intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath) - } - val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(hadoopConfiguration) - assert(!fs.exists(path)) - } - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - } - } - - testQuietly("SPARK-9849 DirectParquetOutputCommitter qualified name backwards compatibility") { - val clonedConf = new Configuration(hadoopConfiguration) - - // Write to a parquet file and let it fail. - // _temporary should be missing if direct output committer works. - try { - hadoopConfiguration.set("spark.sql.parquet.output.committer.class", - "org.apache.spark.sql.parquet.DirectParquetOutputCommitter") - sqlContext.udf.register("div0", (x: Int) => x / 0) - withTempPath { dir => - intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1) as div0").write.parquet(dir.getCanonicalPath) - } - val path = new Path(dir.getCanonicalPath, "_temporary") - val fs = path.getFileSystem(hadoopConfiguration) - assert(!fs.exists(path)) - } - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - } - } - - test("SPARK-8121: spark.sql.parquet.output.committer.class shouldn't be overridden") { withTempPath { dir => val clonedConf = new Configuration(hadoopConfiguration) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index ea7e905742..10eeb30242 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -668,40 +668,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes df.write.format(dataSourceName).partitionBy("c", "d", "e").saveAsTable("t") } } - - test("SPARK-9899 Disable customized output committer when speculation is on") { - val clonedConf = new Configuration(hadoopConfiguration) - val speculationEnabled = - sqlContext.sparkContext.conf.getBoolean("spark.speculation", defaultValue = false) - - try { - withTempPath { dir => - // Enables task speculation - sqlContext.sparkContext.conf.set("spark.speculation", "true") - - // Uses a customized output committer which always fails - hadoopConfiguration.set( - SQLConf.OUTPUT_COMMITTER_CLASS.key, - classOf[AlwaysFailOutputCommitter].getName) - - // Code below shouldn't throw since customized output committer should be disabled. 
- val df = sqlContext.range(10).toDF().coalesce(1) - df.write.format(dataSourceName).save(dir.getCanonicalPath) - checkAnswer( - sqlContext - .read - .format(dataSourceName) - .option("dataSchema", df.schema.json) - .load(dir.getCanonicalPath), - df) - } - } finally { - // Hadoop 1 doesn't have `Configuration.unset` - hadoopConfiguration.clear() - clonedConf.asScala.foreach(entry => hadoopConfiguration.set(entry.getKey, entry.getValue)) - sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) - } - } } // This class is used to test SPARK-8578. We should not use any custom output committer when -- cgit v1.2.3 From db75ccb5522ffdb8cf8fa2531297a2c1d883c283 Mon Sep 17 00:00:00 2001 From: Malte Date: Thu, 7 Apr 2016 09:16:07 +0100 Subject: Better host description for multi-master mesos ## What changes were proposed in this pull request? Since not having the correct zk url causes job failure, the documentation should include all parameters ## How was this patch tested? no tests necessary Author: Malte Closes #12218 from elmalto/patch-1. --- docs/running-on-mesos.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md index 8e47301a75..4a0ab623c1 100644 --- a/docs/running-on-mesos.md +++ b/docs/running-on-mesos.md @@ -108,7 +108,7 @@ the `dev/make-distribution.sh` script included in a Spark source tarball/checkou ## Using a Mesos Master URL The Master URLs for Mesos are in the form `mesos://host:5050` for a single-master Mesos -cluster, or `mesos://zk://host:2181` for a multi-master Mesos cluster using ZooKeeper. +cluster, or `mesos://zk://host1:2181,host2:2181,host3:2181/mesos` for a multi-master Mesos cluster using ZooKeeper. ## Client Mode -- cgit v1.2.3 From 35e0db2d45e2f98d8b4d2c0d442ab19cd615830e Mon Sep 17 00:00:00 2001 From: Alex Bozarth Date: Thu, 7 Apr 2016 09:15:00 -0500 Subject: [SPARK-14245][WEB UI] Display the user in the application view ## What changes were proposed in this pull request? The Spark UI (both active and history) should show the user who ran the application somewhere when you are in the application view. This was added under the Jobs view by total uptime and scheduler mode. ## How was this patch tested? Manual testing username Author: Alex Bozarth Closes #12123 from ajbozarth/spark14245. --- core/src/main/scala/org/apache/spark/ui/SparkUI.scala | 4 ++++ core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala | 4 ++++ core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala | 2 ++ 3 files changed, 10 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 6057522509..39155ff264 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -80,6 +80,10 @@ private[spark] class SparkUI private ( } initialize() + def getSparkUser: String = { + environmentListener.systemProperties.toMap.get("user.name").getOrElse("") + } + def getAppName: String = appName def setAppId(id: String): Unit = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index d5f15f160b..07484c9550 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -296,6 +296,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { val summary: NodeSeq =
+          <li>
+            <strong>User:</strong>
+            {parent.getSparkUser}
+          </li>
    • Total Uptime: { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala index 0d0e9b00d3..7b00b558d5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -31,6 +31,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { def isFairScheduler: Boolean = jobProgresslistener.schedulingMode == Some(SchedulingMode.FAIR) + def getSparkUser: String = parent.getSparkUser + attachPage(new AllJobsPage(this)) attachPage(new JobPage(this)) } -- cgit v1.2.3 From 033d8081525a7137085ec898e2426a58056ee2b8 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Thu, 7 Apr 2016 10:39:21 -0500 Subject: [SPARK-12384] Enables spark-clients to set the min(-Xms) and max(*.memory config) j… MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? Currently Spark clients are started with the same memory setting for Xms and Xms leading to reserving unnecessary higher amounts of memory. This behavior is changed and the clients can now specify an initial heap size using the extraJavaOptions in the config for driver,executor and am individually. Note, that only -Xms can be provided through this config option, if the client wants to set the max size(-Xmx), this has to be done via the *.memory configuration knobs which are currently supported. ## How was this patch tested? Monitored executor and yarn logs in debug mode to verify the commands through which they are being launched in client and cluster mode. The driver memory was verified locally using jps -v. Setting up -Xmx parameter in the javaExtraOptions raises exception with the info provided. Author: Dhruve Ashar Closes #12115 from dhruve/impr/SPARK-12384. --- core/src/main/scala/org/apache/spark/SparkConf.scala | 6 +++--- .../apache/spark/launcher/WorkerCommandBuilder.scala | 1 - .../apache/spark/launcher/AbstractCommandBuilder.java | 3 ++- .../spark/launcher/SparkClassCommandBuilder.java | 13 ++++++++++--- .../spark/launcher/SparkSubmitCommandBuilder.java | 19 +++++++++++++++---- .../launcher/SparkSubmitCommandBuilderSuite.java | 4 +--- .../scala/org/apache/spark/deploy/yarn/Client.scala | 8 ++++---- .../apache/spark/deploy/yarn/ExecutorRunnable.scala | 1 - 8 files changed, 35 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e0fd248c43..acce6bc24f 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -456,9 +456,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { "Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit." throw new Exception(msg) } - if (javaOpts.contains("-Xmx") || javaOpts.contains("-Xms")) { - val msg = s"$executorOptsKey is not allowed to alter memory settings (was '$javaOpts'). " + - "Use spark.executor.memory instead." + if (javaOpts.contains("-Xmx")) { + val msg = s"$executorOptsKey is not allowed to specify max heap memory settings " + + s"(was '$javaOpts'). Use spark.executor.memory instead." 
throw new Exception(msg) } } diff --git a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala index a2add61617..31b9c5edf0 100644 --- a/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala +++ b/core/src/main/scala/org/apache/spark/launcher/WorkerCommandBuilder.scala @@ -37,7 +37,6 @@ private[spark] class WorkerCommandBuilder(sparkHome: String, memoryMb: Int, comm override def buildCommand(env: JMap[String, String]): JList[String] = { val cmd = buildJavaCommand(command.classPathEntries.mkString(File.pathSeparator)) - cmd.add(s"-Xms${memoryMb}M") cmd.add(s"-Xmx${memoryMb}M") command.javaOpts.foreach(cmd.add) CommandBuilderUtils.addPermGenSizeOpt(cmd) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 7a5e37c501..c7488082ca 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -74,7 +74,8 @@ abstract class AbstractCommandBuilder { * SparkLauncher constructor that takes an environment), and may be modified to * include other variables needed by the process to be executed. */ - abstract List buildCommand(Map env) throws IOException; + abstract List buildCommand(Map env) + throws IOException, IllegalArgumentException; /** * Builds a list of arguments to run java. diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java index 6b9d36cc0b..82b593a3f7 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkClassCommandBuilder.java @@ -41,7 +41,8 @@ class SparkClassCommandBuilder extends AbstractCommandBuilder { } @Override - public List buildCommand(Map env) throws IOException { + public List buildCommand(Map env) + throws IOException, IllegalArgumentException { List javaOptsKeys = new ArrayList<>(); String memKey = null; String extraClassPath = null; @@ -80,12 +81,18 @@ class SparkClassCommandBuilder extends AbstractCommandBuilder { } List cmd = buildJavaCommand(extraClassPath); + for (String key : javaOptsKeys) { - addOptionString(cmd, System.getenv(key)); + String envValue = System.getenv(key); + if (!isEmpty(envValue) && envValue.contains("Xmx")) { + String msg = String.format("%s is not allowed to specify max heap(Xmx) memory settings " + + "(was %s). Use the corresponding configuration instead.", key, envValue); + throw new IllegalArgumentException(msg); + } + addOptionString(cmd, envValue); } String mem = firstNonEmpty(memKey != null ? 
System.getenv(memKey) : null, DEFAULT_MEM); - cmd.add("-Xms" + mem); cmd.add("-Xmx" + mem); addPermGenSizeOpt(cmd); cmd.add(className); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index c31c42cd3a..6941ca903c 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -132,7 +132,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { } @Override - public List buildCommand(Map env) throws IOException { + public List buildCommand(Map env) + throws IOException, IllegalArgumentException { if (PYSPARK_SHELL_RESOURCE.equals(appResource) && !printInfo) { return buildPySparkShellCommand(env); } else if (SPARKR_SHELL_RESOURCE.equals(appResource) && !printInfo) { @@ -211,7 +212,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { return args; } - private List buildSparkSubmitCommand(Map env) throws IOException { + private List buildSparkSubmitCommand(Map env) + throws IOException, IllegalArgumentException { // Load the properties file and check whether spark-submit will be running the app's driver // or just launching a cluster app. When running the driver, the JVM's argument will be // modified to cover the driver's configuration. @@ -227,6 +229,16 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { addOptionString(cmd, System.getenv("SPARK_SUBMIT_OPTS")); addOptionString(cmd, System.getenv("SPARK_JAVA_OPTS")); + // We don't want the client to specify Xmx. These have to be set by their corresponding + // memory flag --driver-memory or configuration entry spark.driver.memory + String driverExtraJavaOptions = config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS); + if (!isEmpty(driverExtraJavaOptions) && driverExtraJavaOptions.contains("Xmx")) { + String msg = String.format("Not allowed to specify max heap(Xmx) memory settings through " + + "java options (was %s). Use the corresponding --driver-memory or " + + "spark.driver.memory configuration instead.", driverExtraJavaOptions); + throw new IllegalArgumentException(msg); + } + if (isClientMode) { // Figuring out where the memory value come from is a little tricky due to precedence. // Precedence is observed in the following order: @@ -240,9 +252,8 @@ class SparkSubmitCommandBuilder extends AbstractCommandBuilder { isThriftServer(mainClass) ? 
System.getenv("SPARK_DAEMON_MEMORY") : null; String memory = firstNonEmpty(tsMemory, config.get(SparkLauncher.DRIVER_MEMORY), System.getenv("SPARK_DRIVER_MEMORY"), System.getenv("SPARK_MEM"), DEFAULT_MEM); - cmd.add("-Xms" + memory); cmd.add("-Xmx" + memory); - addOptionString(cmd, config.get(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS)); + addOptionString(cmd, driverExtraJavaOptions); mergeEnvPathList(env, getLibPathEnvName(), config.get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH)); } diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index 29cbbe825b..c7e8b2e03a 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -79,7 +79,6 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { assertTrue(findInStringList(env.get(CommandBuilderUtils.getLibPathEnvName()), File.pathSeparator, "/driverLibPath")); assertTrue(findInStringList(findArgValue(cmd, "-cp"), File.pathSeparator, "/driverCp")); - assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms42g")); assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx42g")); assertTrue("Command should contain user-defined conf.", Collections.indexOfSubList(cmd, Arrays.asList(parser.CONF, "spark.randomOption=foo")) > 0); @@ -202,12 +201,11 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { // Checks below are different for driver and non-driver mode. if (isDriver) { - assertTrue("Driver -Xms should be configured.", cmd.contains("-Xms1g")); assertTrue("Driver -Xmx should be configured.", cmd.contains("-Xmx1g")); } else { boolean found = false; for (String arg : cmd) { - if (arg.startsWith("-Xms") || arg.startsWith("-Xmx")) { + if (arg.startsWith("-Xmx")) { found = true; break; } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 5e7e3be08d..04e91f8553 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -839,16 +839,16 @@ private[spark] class Client( // Validate and include yarn am specific java options in yarn-client mode. sparkConf.get(AM_JAVA_OPTIONS).foreach { opts => if (opts.contains("-Dspark")) { - val msg = s"$${amJavaOptions.key} is not allowed to set Spark options (was '$opts'). " + val msg = s"${AM_JAVA_OPTIONS.key} is not allowed to set Spark options (was '$opts')." throw new SparkException(msg) } - if (opts.contains("-Xmx") || opts.contains("-Xms")) { - val msg = s"$${amJavaOptions.key} is not allowed to alter memory settings (was '$opts')." + if (opts.contains("-Xmx")) { + val msg = s"${AM_JAVA_OPTIONS.key} is not allowed to specify max heap memory settings " + + s"(was '$opts'). Use spark.yarn.am.memory instead." 
throw new SparkException(msg) } javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } - sparkConf.get(AM_LIBRARY_PATH).foreach { paths => prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 7b55d781f8..ef7908a3ef 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -147,7 +147,6 @@ private[yarn] class ExecutorRunnable( // Set the JVM memory val executorMemoryString = executorMemory + "m" - javaOpts += "-Xms" + executorMemoryString javaOpts += "-Xmx" + executorMemoryString // Set extra Java options for the executor, if defined -- cgit v1.2.3 From 3aa7d76395a76fb804fc2f51a39c3179208c33a5 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Thu, 7 Apr 2016 10:51:49 -0700 Subject: [SQL][TESTS] Fix for flaky test in ContinuousQueryManagerSuite ## What changes were proposed in this pull request? The timeouts were lower the other timeouts in the test. Other tests were stable over the last month. ## How was this patch tested? Jenkins tests. Author: Tathagata Das Closes #12219 from tdas/flaky-test-fix. --- .../org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index 33787de9da..3d69c8a187 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -185,8 +185,8 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val q2 = stopRandomQueryAsync(100 milliseconds, withError = true) testAwaitAnyTermination( ExpectException[SparkException], - awaitTimeout = 1 seconds, - testBehaviorFor = 2 seconds) + awaitTimeout = 4 seconds, + testBehaviorFor = 6 seconds) require(!q2.isActive) // should be inactive by the time the prev awaitAnyTerm returned // All subsequent calls to awaitAnyTermination should throw the exception -- cgit v1.2.3 From 8dcb0c7c974e9707933ac2ae6ce837e765a5e81a Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 7 Apr 2016 11:03:39 -0700 Subject: [SPARK-14456][SQL][MINOR] Remove unused variables and logics in DataSource ## What changes were proposed in this pull request? In DataSource#write method, the variables `dataSchema` and `equality`, and related logics are no longer used. Let's remove them. ## How was this patch tested? Existing tests. Author: Kousuke Saruta Closes #12237 from sarutak/SPARK-14456. 
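For context, the block being dropped computes a "non-partition" schema with a case-(in)sensitive resolver and then never uses it. A minimal standalone sketch of that dead computation (hypothetical helper; the real code lives inline in `DataSource#write` and uses the analyzer's `caseSensitiveResolution`/`caseInsensitiveResolution` functions):

```scala
import org.apache.spark.sql.types.StructType

// Hypothetical standalone form of the removed logic: strip partition columns from a
// schema using a case-(in)sensitive column comparison. Nothing in DataSource#write
// reads the result, which is why the block can simply be deleted.
def nonPartitionSchema(
    schema: StructType,
    partitionColumns: Seq[String],
    caseSensitive: Boolean): StructType = {
  val equality: (String, String) => Boolean =
    if (caseSensitive) (a, b) => a == b
    else (a, b) => a.equalsIgnoreCase(b)
  StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name))))
}
```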
--- .../apache/spark/sql/execution/datasources/DataSource.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 1850810270..f55cedb1b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -348,16 +348,6 @@ case class DataSource( PartitioningUtils.validatePartitionColumnDataTypes( data.schema, partitionColumns, caseSensitive) - val equality = - if (sqlContext.conf.caseSensitiveAnalysis) { - org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution - } else { - org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution - } - - val dataSchema = StructType( - data.schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) - // If we are appending to a table that already exists, make sure the partitioning matches // up. If we fail to load the table for whatever reason, ignore the check. if (mode == SaveMode.Append) { -- cgit v1.2.3 From aa852215f82876977d164f371627e894e86baacc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 7 Apr 2016 11:51:34 -0700 Subject: [SPARK-12740] [SPARK-13932] support grouping()/grouping_id() in having/order clause ## What changes were proposed in this pull request? This PR brings the support of using grouping()/grouping_id() in HAVING/ORDER BY clause. The resolved grouping()/grouping_id() will be replaced by unresolved "spark_gropuing_id" virtual attribute, then resolved by ResolveMissingAttribute. This PR also fix the HAVING clause that access a grouping column that is not presented in SELECT clause, for example: ```sql select count(1) from (select 1 as a) t group by a having a > 0 ``` ## How was this patch tested? Add new tests. Author: Davies Liu Closes #12235 from davies/grouping_having. 
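To make the supported query shapes concrete, here is a sketch built from the `SQLQuerySuite` cases added below (it assumes the suite's `sql` helper and its `courseSales` test table are in scope):

```scala
// grouping()/grouping_id() in HAVING over a CUBE grouping
sql("""select course, year from courseSales
      |group by cube(course, year)
      |having grouping(year) = 1 and grouping_id(course, year) > 0""".stripMargin)

// grouping()/grouping_id() in ORDER BY
sql("""select course, year, grouping(course), grouping(year) from courseSales
      |group by cube(course, year)
      |order by grouping_id(course, year), course, year""".stripMargin)

// HAVING may reference a grouping column that is not in the SELECT list
sql("select count(1) from (select 1 as a) t group by a having a > 0")
```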
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 181 ++++++++++++++------- .../catalyst/expressions/namedExpressions.scala | 4 +- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 82 ++++++++++ 3 files changed, 211 insertions(+), 56 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bc8cf4e78a..7bcba421fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -87,7 +87,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: - ResolveSortReferences :: + ResolveMissingReferences :: ResolveGenerate :: ResolveFunctions :: ResolveAliases :: @@ -228,21 +228,56 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - private def hasGroupingId(expr: Seq[Expression]): Boolean = { - expr.exists(_.collectFirst { - case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u - }.isDefined) + private def hasGroupingAttribute(expr: Expression): Boolean = { + expr.collectFirst { + case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u + }.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + private def hasGroupingFunction(e: Expression): Boolean = { + e.collectFirst { + case g: Grouping => g + case g: GroupingID => g + }.isDefined + } + + private def replaceGroupingFunc( + expr: Expression, + groupByExprs: Seq[Expression], + gid: Expression): Expression = { + expr transform { + case e: GroupingID => + if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { + gid + } else { + throw new AnalysisException( + s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + + s"grouping columns (${groupByExprs.mkString(",")})") + } + case Grouping(col: Expression) => + val idx = groupByExprs.indexOf(col) + if (idx >= 0) { + Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), + Literal(1)), ByteType) + } else { + throw new AnalysisException(s"Column of grouping ($col) can't be found " + + s"in grouping columns ${groupByExprs.mkString(",")}") + } + } + } + + // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. + case p if p.expressions.exists(hasGroupingAttribute) => + failAnalysis( + s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) - case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) => - failAnalysis( - s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead") + // Ensure all the expressions have been resolved. 
case x: GroupingSets if x.expressions.forall(_.resolved) => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() @@ -270,7 +305,7 @@ class Analyzer( def isPartOfAggregation(e: Expression): Boolean = { aggsBuffer.exists(a => a.find(_ eq e).isDefined) } - expr.transformDown { + replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown { // AggregateExpression should be computed on the unmodified value of its argument // expressions, so we should not replace any references to grouping expression // inside it. @@ -278,23 +313,6 @@ class Analyzer( aggsBuffer += e e case e if isPartOfAggregation(e) => e - case e: GroupingID => - if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) { - gid - } else { - throw new AnalysisException( - s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + - s"grouping columns (${x.groupByExprs.mkString(",")})") - } - case Grouping(col: Expression) => - val idx = x.groupByExprs.indexOf(col) - if (idx >= 0) { - Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)), - Literal(1)), ByteType) - } else { - throw new AnalysisException(s"Column of grouping ($col) can't be found " + - s"in grouping columns ${x.groupByExprs.mkString(",")}") - } case e => val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) if (index == -1) { @@ -306,9 +324,37 @@ class Analyzer( } Aggregate( - groupByAttributes :+ VirtualColumn.groupingIdAttribute, + groupByAttributes :+ gid, aggregations, Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)) + + case f @ Filter(cond, child) if hasGroupingFunction(cond) => + val groupingExprs = findGroupingExprs(child) + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) + f.copy(condition = newCond) + + case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) => + val groupingExprs = findGroupingExprs(child) + val gid = VirtualColumn.groupingIdAttribute + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) + s.copy(order = newOrder) + } + + private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { + plan.collectFirst { + case a: Aggregate => + // this Aggregate should have grouping id as the last grouping key. + val gid = a.groupingExpressions.last + if (!gid.isInstanceOf[AttributeReference] + || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + a.groupingExpressions.take(a.groupingExpressions.length - 1) + }.getOrElse { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } } } @@ -663,13 +709,15 @@ class Analyzer( * clause. This rule detects such queries and adds the required attributes to the original * projection, so that they will be available during sorting. Another projection is added to * remove these attributes after sorting. + * + * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ - object ResolveSortReferences extends Rule[LogicalPlan] { + object ResolveMissingReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip sort with aggregate. 
This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) if child.resolved => try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -689,6 +737,26 @@ class Analyzer( // in Sort case ae: AnalysisException => s } + + case f @ Filter(cond, child) if child.resolved => + try { + val newCond = resolveExpressionRecursively(cond, child) + val requiredAttrs = newCond.references.filter(_.resolved) + val missingAttrs = requiredAttrs -- child.outputSet + if (missingAttrs.nonEmpty) { + // Add missing attributes and then project them away. + Project(child.output, + Filter(newCond, addMissingAttr(child, missingAttrs))) + } else if (newCond != cond) { + f.copy(condition = newCond) + } else { + f + } + } catch { + // Attempting to resolve it might fail. When this happens, return the original plan. + // Users will see an AnalysisException for resolution failure of missing attributes + case ae: AnalysisException => f + } } /** @@ -843,27 +911,33 @@ class Analyzer( if aggregate.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause - val aggregatedCondition = - Aggregate( - grouping, - Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, - child) - val resolvedOperator = execute(aggregatedCondition) - def resolvedAggregateFilter = - resolvedOperator - .asInstanceOf[Aggregate] - .aggregateExpressions.head - - // If resolution was successful and we see the filter has an aggregate in it, add it to - // the original aggregate operator. - if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { - val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs - - Project(aggregate.output, - Filter(resolvedAggregateFilter.toAttribute, - aggregate.copy(aggregateExpressions = aggExprsWithHaving))) - } else { - filter + try { + val aggregatedCondition = + Aggregate( + grouping, + Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + child) + val resolvedOperator = execute(aggregatedCondition) + def resolvedAggregateFilter = + resolvedOperator + .asInstanceOf[Aggregate] + .aggregateExpressions.head + + // If resolution was successful and we see the filter has an aggregate in it, add it to + // the original aggregate operator. + if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { + val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs + + Project(aggregate.output, + Filter(resolvedAggregateFilter.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } else { + filter + } + } catch { + // Attempting to resolve in the aggregate can result in ambiguity. When this happens, + // just return the original plan. 
+ case ae: AnalysisException => filter } case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => @@ -927,11 +1001,8 @@ class Analyzer( } } - private def isAggregateExpression(e: Expression): Boolean = { - e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] - } def containsAggregate(condition: Expression): Boolean = { - condition.find(isAggregateExpression).isDefined + condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 2307122ea1..78310fb2f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -333,6 +333,8 @@ case class PrettyAttribute( } object VirtualColumn { - val groupingIdName: String = "grouping__id" + // The attribute name used by Hive, which has different result than Spark, deprecated. + val hiveGroupingIdName: String = "grouping__id" + val groupingIdName: String = "spark_grouping_id" val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2ab7c1581c..dd648cdb81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2230,6 +2230,88 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") } + test("grouping and grouping_id in having") { + checkAnswer( + sql("select course, year from courseSales group by cube(course, year)" + + " having grouping(year) = 1 and grouping_id(course, year) > 0"), + Row("Java", null) :: + Row("dotNET", null) :: + Row(null, null) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " having grouping(course) > 0") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " having grouping_id(course, year) > 0") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by cube(course, year)" + + " having grouping__id > 0") + } + assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") + } + + test("grouping and grouping_id in sort") { + checkAnswer( + sql("select course, year, grouping(course), grouping(year) from courseSales" + + " group by cube(course, year) order by grouping_id(course, year), course, year"), + Row("Java", 2012, 0, 0) :: + Row("Java", 2013, 0, 0) :: + Row("dotNET", 2012, 0, 0) :: + Row("dotNET", 2013, 0, 0) :: + Row("Java", null, 0, 1) :: + Row("dotNET", null, 0, 1) :: + Row(null, 2012, 1, 0) :: + Row(null, 2013, 1, 0) :: + Row(null, null, 1, 1) :: Nil + ) + + checkAnswer( + sql("select course, year, grouping_id(course, year) from courseSales" + + " group by cube(course, year) order by grouping(course), 
grouping(year), course, year"), + Row("Java", 2012, 0) :: + Row("Java", 2013, 0) :: + Row("dotNET", 2012, 0) :: + Row("dotNET", 2013, 0) :: + Row("Java", null, 1) :: + Row("dotNET", null, 1) :: + Row(null, 2012, 2) :: + Row(null, 2013, 2) :: + Row(null, null, 3) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " order by grouping(course)") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " order by grouping_id(course, year)") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by cube(course, year)" + + " order by grouping__id") + } + assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") + } + + test("filter on a grouping column that is not presented in SELECT") { + checkAnswer( + sql("select count(1) from (select 1 as a) t group by a having a > 0"), + Row(1) :: Nil) + } + test("SPARK-13056: Null in map value causes NPE") { val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") withTempTable("maptest") { -- cgit v1.2.3 From ae1db91d158d1ae62a0ab7ea74467679ca050101 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 7 Apr 2016 16:23:17 -0700 Subject: [SPARK-14410][SQL] Push functions existence check into catalog ## What changes were proposed in this pull request? This is a followup to #12117 and addresses some of the TODOs introduced there. In particular, the resolution of database is now pushed into session catalog, which knows about the current database. Further, the logic for checking whether a function exists is pushed into the external catalog. No change in functionality is expected. ## How was this patch tested? `SessionCatalogSuite`, `DDLSuite` Author: Andrew Or Closes #12198 from andrewor14/function-exists. 
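A condensed sketch of the new lookup path, simplified from the `SessionCatalog`/`ExternalCatalog` changes below (the classes here are illustrative stand-ins, not the real ones, and error handling is omitted):

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier

// Stand-in for ExternalCatalog: existence is now answered by the catalog itself,
// instead of callers probing getFunction and catching exceptions.
trait ExternalCatalogSketch {
  def functionExists(db: String, funcName: String): Boolean
}

// Stand-in for SessionCatalog: it resolves the database (it knows the current one)
// and checks the function registry before asking the external catalog.
class SessionCatalogSketch(
    externalCatalog: ExternalCatalogSketch,
    registryContains: String => Boolean,
    currentDb: String) {

  def functionExists(name: FunctionIdentifier): Boolean = {
    val db = name.database.getOrElse(currentDb)
    registryContains(name.unquotedString) ||
      externalCatalog.functionExists(db, name.funcName)
  }
}
```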
--- .../sql/catalyst/catalog/InMemoryCatalog.scala | 10 ++-- .../sql/catalyst/catalog/SessionCatalog.scala | 59 ++++++++++------------ .../spark/sql/catalyst/catalog/interface.scala | 2 + .../sql/catalyst/catalog/SessionCatalogSuite.scala | 53 ++++++++++--------- .../spark/sql/execution/command/commands.scala | 2 +- .../apache/spark/sql/execution/command/ddl.scala | 13 ++--- .../spark/sql/execution/command/functions.scala | 31 ++++-------- .../spark/sql/execution/command/DDLSuite.scala | 41 +++++++-------- .../spark/sql/hive/HiveExternalCatalog.scala | 4 ++ .../apache/spark/sql/hive/client/HiveClient.scala | 5 ++ .../spark/sql/hive/client/HiveClientImpl.scala | 10 ++-- .../spark/sql/hive/HiveMetastoreCatalogSuite.scala | 8 +-- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 13 files changed, 126 insertions(+), 114 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 5d136b663f..186bbccef1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog { // Database name -> description private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] - private def functionExists(db: String, funcName: String): Boolean = { - requireDbExists(db) - catalog(db).functions.contains(funcName) - } - private def partitionExists(db: String, table: String, spec: TablePartitionSpec): Boolean = { requireTableExists(db, table) catalog(db).tables(table).partitions.contains(spec) @@ -315,6 +310,11 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).functions(funcName) } + override def functionExists(db: String, funcName: String): Boolean = { + requireDbExists(db) + catalog(db).functions.contains(funcName) + } + override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { requireDbExists(db) StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 2acf584e8f..7db9fd0527 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -95,7 +95,7 @@ class SessionCatalog( externalCatalog.alterDatabase(dbDefinition) } - def getDatabase(db: String): CatalogDatabase = { + def getDatabaseMetadata(db: String): CatalogDatabase = { externalCatalog.getDatabase(db) } @@ -169,7 +169,7 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. * If the specified table is not found in the database then an [[AnalysisException]] is thrown. */ - def getTable(name: TableIdentifier): CatalogTable = { + def getTableMetadata(name: TableIdentifier): CatalogTable = { val db = name.database.getOrElse(currentDb) val table = formatTableName(name.table) externalCatalog.getTable(db, table) @@ -435,28 +435,37 @@ class SessionCatalog( * Create a metastore function in the database specified in `funcDefinition`. * If no such database is specified, create it in the current database. 
*/ - def createFunction(funcDefinition: CatalogFunction): Unit = { + def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { val db = funcDefinition.identifier.database.getOrElse(currentDb) - val newFuncDefinition = funcDefinition.copy( - identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))) - externalCatalog.createFunction(db, newFuncDefinition) + val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)) + val newFuncDefinition = funcDefinition.copy(identifier = identifier) + if (!functionExists(identifier)) { + externalCatalog.createFunction(db, newFuncDefinition) + } else if (!ignoreIfExists) { + throw new AnalysisException(s"function '$identifier' already exists in database '$db'") + } } /** * Drop a metastore function. * If no database is specified, assume the function is in the current database. */ - def dropFunction(name: FunctionIdentifier): Unit = { + def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = { val db = name.database.getOrElse(currentDb) - val qualified = name.copy(database = Some(db)).unquotedString - if (functionRegistry.functionExists(qualified)) { - // If we have loaded this function into the FunctionRegistry, - // also drop it from there. - // For a permanent function, because we loaded it to the FunctionRegistry - // when it's first used, we also need to drop it from the FunctionRegistry. - functionRegistry.dropFunction(qualified) + val identifier = name.copy(database = Some(db)) + if (functionExists(identifier)) { + // TODO: registry should just take in FunctionIdentifier for type safety + if (functionRegistry.functionExists(identifier.unquotedString)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(identifier.unquotedString) + } + externalCatalog.dropFunction(db, name.funcName) + } else if (!ignoreIfNotExists) { + throw new AnalysisException(s"function '$identifier' does not exist in database '$db'") } - externalCatalog.dropFunction(db, name.funcName) } /** @@ -465,8 +474,7 @@ class SessionCatalog( * If a database is specified in `name`, this will return the function in that database. * If no database is specified, this will return the function in the current database. */ - // TODO: have a better name. This method is actually for fetching the metadata of a function. - def getFunction(name: FunctionIdentifier): CatalogFunction = { + def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = { val db = name.database.getOrElse(currentDb) externalCatalog.getFunction(db, name.funcName) } @@ -475,20 +483,9 @@ class SessionCatalog( * Check if the specified function exists. */ def functionExists(name: FunctionIdentifier): Boolean = { - if (functionRegistry.functionExists(name.unquotedString)) { - // This function exists in the FunctionRegistry. - true - } else { - // Need to check if this function exists in the metastore. - try { - // TODO: It's better to ask external catalog if this function exists. - // So, we can avoid of having this hacky try/catch block. - getFunction(name) != null - } catch { - case _: NoSuchFunctionException => false - case _: AnalysisException => false // HiveExternalCatalog wraps all exceptions with it. 
- } - } + val db = name.database.getOrElse(currentDb) + functionRegistry.functionExists(name.unquotedString) || + externalCatalog.functionExists(db, name.funcName) } // ---------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 97b9946140..e29d6bd8b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -152,6 +152,8 @@ abstract class ExternalCatalog { def getFunction(db: String, funcName: String): CatalogFunction + def functionExists(db: String, funcName: String): Boolean + def listFunctions(db: String, pattern: String): Seq[String] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 4d56d001b3..1850dc8156 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -62,7 +62,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("get database when a database exists") { val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabase("db1") + val db1 = catalog.getDatabaseMetadata("db1") assert(db1.name == "db1") assert(db1.description.contains("db1")) } @@ -70,7 +70,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("get database should throw exception when the database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.getDatabase("db_that_does_not_exist") + catalog.getDatabaseMetadata("db_that_does_not_exist") } } @@ -128,10 +128,10 @@ class SessionCatalogSuite extends SparkFunSuite { test("alter database") { val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabase("db1") + val db1 = catalog.getDatabaseMetadata("db1") // Note: alter properties here because Hive does not support altering other fields catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) - val newDb1 = catalog.getDatabase("db1") + val newDb1 = catalog.getDatabaseMetadata("db1") assert(db1.properties.isEmpty) assert(newDb1.properties.size == 2) assert(newDb1.properties.get("k") == Some("v3")) @@ -346,21 +346,21 @@ class SessionCatalogSuite extends SparkFunSuite { test("get table") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - assert(sessionCatalog.getTable(TableIdentifier("tbl1", Some("db2"))) + assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) == externalCatalog.getTable("db2", "tbl1")) // Get table without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTable(TableIdentifier("tbl1")) + assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1")) == externalCatalog.getTable("db2", "tbl1")) } test("get table when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.getTable(TableIdentifier("tbl1", Some("unknown_db"))) + catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) } intercept[AnalysisException] { - 
catalog.getTable(TableIdentifier("unknown_table", Some("db2"))) + catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) } } @@ -386,7 +386,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("lookup table relation with alias") { val catalog = new SessionCatalog(newBasicCatalog()) val alias = "monster" - val tableMetadata = catalog.getTable(TableIdentifier("tbl1", Some("db2"))) + val tableMetadata = catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) val relation = SubqueryAlias("tbl1", CatalogRelation("db2", tableMetadata)) val relationWithAlias = SubqueryAlias(alias, @@ -659,26 +659,28 @@ class SessionCatalogSuite extends SparkFunSuite { val externalCatalog = newEmptyCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - sessionCatalog.createFunction(newFunc("myfunc", Some("mydb"))) + sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) // Create function without explicitly specifying database sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createFunction(newFunc("myfunc2")) + sessionCatalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) } test("create function when database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.createFunction(newFunc("func5", Some("does_not_exist"))) + catalog.createFunction( + newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) } } test("create function that already exists") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.createFunction(newFunc("func1", Some("db2"))) + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) } + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) } test("create temp function") { @@ -711,24 +713,27 @@ class SessionCatalogSuite extends SparkFunSuite { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - sessionCatalog.dropFunction(FunctionIdentifier("func1", Some("db2"))) + sessionCatalog.dropFunction( + FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) assert(externalCatalog.listFunctions("db2", "*").isEmpty) // Drop function without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.createFunction(newFunc("func2", Some("db2"))) + sessionCatalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) - sessionCatalog.dropFunction(FunctionIdentifier("func2")) + sessionCatalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) assert(externalCatalog.listFunctions("db2", "*").isEmpty) } test("drop function when database/function does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.dropFunction(FunctionIdentifier("something", Some("does_not_exist"))) + catalog.dropFunction( + FunctionIdentifier("something", Some("does_not_exist")), ignoreIfNotExists = false) } intercept[AnalysisException] { - catalog.dropFunction(FunctionIdentifier("does_not_exist")) + 
catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) } + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) } test("drop temp function") { @@ -753,19 +758,19 @@ class SessionCatalogSuite extends SparkFunSuite { val expected = CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, Seq.empty[(String, String)]) - assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == expected) + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) // Get function without explicitly specifying database catalog.setCurrentDatabase("db2") - assert(catalog.getFunction(FunctionIdentifier("func1")) == expected) + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) } test("get function when database/function does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.getFunction(FunctionIdentifier("func1", Some("does_not_exist"))) + catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("does_not_exist"))) } intercept[AnalysisException] { - catalog.getFunction(FunctionIdentifier("does_not_exist", Some("db2"))) + catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) } } @@ -787,8 +792,8 @@ class SessionCatalogSuite extends SparkFunSuite { val info2 = new ExpressionInfo("tempFunc2", "yes_me") val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last - catalog.createFunction(newFunc("func2", Some("db2"))) - catalog.createFunction(newFunc("not_me", Some("db2"))) + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) assert(catalog.listFunctions("db1", "*").toSet == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index faa7a2cdb4..3fd2a93d29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -407,7 +407,7 @@ case class ShowTablePropertiesCommand( if (catalog.isTemporaryTable(table)) { Seq.empty[Row] } else { - val catalogTable = sqlContext.sessionState.catalog.getTable(table) + val catalogTable = sqlContext.sessionState.catalog.getTableMetadata(table) propertyKey match { case Some(p) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 6d56a6fec8..20779d68e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -124,7 +124,7 @@ case class AlterDatabaseProperties( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val db: CatalogDatabase = catalog.getDatabase(databaseName) + val db: CatalogDatabase = catalog.getDatabaseMetadata(databaseName) catalog.alterDatabase(db.copy(properties = db.properties ++ props)) Seq.empty[Row] @@ -149,7 +149,8 @@ case class DescribeDatabase( extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - val dbMetadata: 
CatalogDatabase = sqlContext.sessionState.catalog.getDatabase(databaseName) + val dbMetadata: CatalogDatabase = + sqlContext.sessionState.catalog.getDatabaseMetadata(databaseName) val result = Row("Database Name", dbMetadata.name) :: Row("Description", dbMetadata.description) :: @@ -213,7 +214,7 @@ case class AlterTableSetProperties( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val table = catalog.getTable(tableName) + val table = catalog.getTableMetadata(tableName) val newProperties = table.properties ++ properties if (DDLUtils.isDatasourceTable(newProperties)) { throw new AnalysisException( @@ -243,7 +244,7 @@ case class AlterTableUnsetProperties( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val table = catalog.getTable(tableName) + val table = catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( "alter table properties is not supported for datasource tables") @@ -286,7 +287,7 @@ case class AlterTableSerDeProperties( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val table = catalog.getTable(tableName) + val table = catalog.getTableMetadata(tableName) // Do not support setting serde for datasource tables if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( @@ -376,7 +377,7 @@ case class AlterTableSetLocation( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val table = catalog.getTable(tableName) + val table = catalog.getTableMetadata(tableName) partitionSpec match { case Some(spec) => // Partition spec is specified, so we set the location only for this partition diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 66d17e322e..c6e601799f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -47,6 +47,7 @@ case class CreateFunction( extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog if (isTemp) { if (databaseName.isDefined) { throw new AnalysisException( @@ -55,24 +56,18 @@ case class CreateFunction( } // We first load resources and then put the builder in the function registry. // Please note that it is allowed to overwrite an existing temp function. - sqlContext.sessionState.catalog.loadFunctionResources(resources) + catalog.loadFunctionResources(resources) val info = new ExpressionInfo(className, functionName) - val builder = - sqlContext.sessionState.catalog.makeFunctionBuilder(functionName, className) - sqlContext.sessionState.catalog.createTempFunction( - functionName, info, builder, ignoreIfExists = false) + val builder = catalog.makeFunctionBuilder(functionName, className) + catalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. 
- val dbName = databaseName.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase) - val func = FunctionIdentifier(functionName, Some(dbName)) - val catalogFunc = CatalogFunction(func, className, resources) - if (sqlContext.sessionState.catalog.functionExists(func)) { - throw new AnalysisException( - s"Function '$functionName' already exists in database '$dbName'.") - } - sqlContext.sessionState.catalog.createFunction(catalogFunc) + // TODO: should we also parse "IF NOT EXISTS"? + catalog.createFunction( + CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources), + ignoreIfExists = false) } Seq.empty[Row] } @@ -101,13 +96,9 @@ case class DropFunction( catalog.dropTempFunction(functionName, ifExists) } else { // We are dropping a permanent function. - val dbName = databaseName.getOrElse(catalog.getCurrentDatabase) - val func = FunctionIdentifier(functionName, Some(dbName)) - if (!ifExists && !catalog.functionExists(func)) { - throw new AnalysisException( - s"Function '$functionName' does not exist in database '$dbName'.") - } - catalog.dropFunction(func) + catalog.dropFunction( + FunctionIdentifier(functionName, databaseName), + ignoreIfNotExists = ifExists) } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index a8db4e9923..7084665b3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -88,7 +88,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") - val db1 = catalog.getDatabase(dbNameWithoutBackTicks) + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", @@ -110,7 +110,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") - val db1 = catalog.getDatabase(dbNameWithoutBackTicks) + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", @@ -233,14 +233,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) - assert(catalog.getTable(tableIdent).properties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).properties.isEmpty) // set table properties sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')") - assert(catalog.getTable(tableIdent).properties == Map("andrew" -> "or14", "kor" -> "bel")) + assert(catalog.getTableMetadata(tableIdent).properties == + Map("andrew" -> "or14", "kor" -> "bel")) // set table properties without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')") - assert(catalog.getTable(tableIdent).properties == + assert(catalog.getTableMetadata(tableIdent).properties == Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol")) // table to alter does not exist intercept[AnalysisException] { @@ -262,11 +263,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // unset table properties sql("ALTER TABLE dbx.tab1 SET 
TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan')") sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')") - assert(catalog.getTable(tableIdent).properties == Map("p" -> "an", "c" -> "lan")) + assert(catalog.getTableMetadata(tableIdent).properties == Map("p" -> "an", "c" -> "lan")) // unset table properties without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')") - assert(catalog.getTable(tableIdent).properties == Map("c" -> "lan")) + assert(catalog.getTableMetadata(tableIdent).properties == Map("c" -> "lan")) // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE does_not_exist UNSET TBLPROPERTIES ('c' = 'lan')") @@ -278,7 +279,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e.getMessage.contains("xyz")) // property to unset does not exist, but "IF EXISTS" is specified sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')") - assert(catalog.getTable(tableIdent).properties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).properties.isEmpty) // throw exception for datasource tables convertToDatasourceTable(catalog, tableIdent) val e1 = intercept[AnalysisException] { @@ -393,7 +394,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def convertToDatasourceTable( catalog: SessionCatalog, tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTable(tableIdent).copy( + catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( properties = Map("spark.sql.sources.provider" -> "csv"))) } @@ -407,15 +408,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { convertToDatasourceTable(catalog, tableIdent) } - assert(catalog.getTable(tableIdent).storage.locationUri.isEmpty) - assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.serdeProperties.isEmpty) // Verify that the location is set to the expected string def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { val storageFormat = spec .map { s => catalog.getPartition(tableIdent, s).storage } - .getOrElse { catalog.getTable(tableIdent).storage } + .getOrElse { catalog.getTableMetadata(tableIdent).storage } if (isDatasourceTable) { if (spec.isDefined) { assert(storageFormat.serdeProperties.isEmpty) @@ -467,8 +468,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { convertToDatasourceTable(catalog, tableIdent) } - assert(catalog.getTable(tableIdent).storage.serde.isEmpty) - assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) // set table serde and/or properties (should fail on datasource tables) if (isDatasourceTable) { val e1 = intercept[AnalysisException] { @@ -482,22 +483,22 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e2.getMessage.contains("datasource")) } else { sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'") - assert(catalog.getTable(tableIdent).storage.serde == 
Some("org.apache.jadoop")) - assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.jadoop")) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") - assert(catalog.getTable(tableIdent).storage.serde == Some("org.apache.madoop")) - assert(catalog.getTable(tableIdent).storage.serdeProperties == + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.madoop")) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == Map("k" -> "v", "kay" -> "vee")) } // set serde properties only sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") - assert(catalog.getTable(tableIdent).storage.serdeProperties == + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == Map("k" -> "vvv", "kay" -> "vee")) // set things without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')") - assert(catalog.getTable(tableIdent).storage.serdeProperties == + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == Map("k" -> "vvv", "kay" -> "veee")) // table to alter does not exist intercept[AnalysisException] { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 98a5998d03..b1156fb3e2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -292,6 +292,10 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.getFunction(db, funcName) } + override def functionExists(db: String, funcName: String): Boolean = withClient { + client.functionExists(db, funcName) + } + override def listFunctions(db: String, pattern: String): Seq[String] = withClient { client.listFunctions(db, pattern) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index ee56f9d75d..94794b1572 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -232,6 +232,11 @@ private[hive] trait HiveClient { /** Return an existing function in the database, or None if it doesn't exist. */ def getFunctionOption(db: String, name: String): Option[CatalogFunction] + /** Return whether a function exists in the specified database. */ + final def functionExists(db: String, name: String): Boolean = { + getFunctionOption(db, name).isDefined + } + /** Return the names of all functions that match the given pattern in the database. 
*/ def listFunctions(db: String, pattern: String): Seq[String] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 1f66fbfd85..d0eb9ddf50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -21,7 +21,6 @@ import java.io.{File, PrintStream} import scala.collection.JavaConverters._ import scala.language.reflectiveCalls -import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -30,7 +29,8 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState @@ -559,7 +559,11 @@ private[hive] class HiveClientImpl( override def getFunctionOption( db: String, name: String): Option[CatalogFunction] = withHiveState { - Option(client.getFunction(db, name)).map(fromHiveFunction) + try { + Option(client.getFunction(db, name)).map(fromHiveFunction) + } catch { + case he: HiveException => None + } } override def listFunctions(db: String, pattern: String): Seq[String] = withHiveState { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 6967395613..ada8621d07 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -83,7 +83,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .saveAsTable("t") } - val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) + val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) @@ -114,7 +114,8 @@ class DataSourceWithHiveMetastoreCatalogSuite .saveAsTable("t") } - val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) + val hiveTable = + sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) @@ -144,7 +145,8 @@ class DataSourceWithHiveMetastoreCatalogSuite |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) - val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) + val hiveTable = + sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) 
assert(hiveTable.storage.serde === Some(serde)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index dd2129375d..c5417b06a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -354,7 +354,7 @@ object PermanentHiveUDFTest2 extends Logging { FunctionIdentifier("example_max"), "org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax", ("JAR" -> jar) :: Nil) - hiveContext.sessionState.catalog.createFunction(function) + hiveContext.sessionState.catalog.createFunction(function, ignoreIfExists = false) val source = hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") source.registerTempTable("sourceTable") -- cgit v1.2.3 From 49fb237081bbca0d811aa48aa06f4728fea62781 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Apr 2016 17:23:34 -0700 Subject: [SPARK-14270][SQL] whole stage codegen support for typed filter ## What changes were proposed in this pull request? We implement typed filter by `MapPartitions`, which doesn't work well with whole stage codegen. This PR uses `Filter` to implement typed filter, so we get whole stage codegen support for free. This PR also introduces `DeserializeToObject` and `SerializeFromObject`, to separate serialization logic from object operators, so that it's easier to write optimization rules for adjacent object operators. ## How was this patch tested? Existing tests. Author: Wenchen Fan Closes #12061 from cloud-fan/whole-stage-codegen. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 + .../apache/spark/sql/catalyst/dsl/package.scala | 19 ++++++ .../spark/sql/catalyst/optimizer/Optimizer.scala | 39 ++++++++++- .../spark/sql/catalyst/plans/logical/object.scala | 37 ++++++++++- .../optimizer/TypedFilterOptimizationSuite.scala | 74 +++++++++++++++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 16 ++++- .../spark/sql/execution/SparkStrategies.scala | 4 ++ .../org/apache/spark/sql/execution/objects.scala | 67 +++++++++++++++++++ .../org/apache/spark/sql/DatasetBenchmark.scala | 76 +++++++++++++++++++--- .../scala/org/apache/spark/sql/QueryTest.scala | 2 + .../sql/execution/WholeStageCodegenSuite.scala | 21 +++++- 11 files changed, 342 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7bcba421fd..3555a6d7fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1670,6 +1670,8 @@ object CleanupAliases extends Rule[LogicalPlan] { // Operators that operate on objects should only have expressions from encoders, which should // never have extra aliases.
case o: ObjectOperator => o + case d: DeserializeToObject => d + case s: SerializeFromObject => s case other => var stop = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 105947028d..1e7296664b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -166,6 +167,14 @@ package object dsl { case target => UnresolvedStar(Option(target)) } + def callFunction[T, U]( + func: T => U, + returnType: DataType, + argument: Expression): Expression = { + val function = Literal.create(func, ObjectType(classOf[T => U])) + Invoke(function, "apply", returnType, argument :: Nil) + } + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { @@ -270,6 +279,16 @@ package object dsl { def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def filter[T : Encoder](func: T => Boolean): LogicalPlan = { + val deserialized = logicalPlan.deserialize[T] + val condition = expressions.callFunction(func, BooleanType, deserialized.output.head) + Filter(condition, deserialized).serialize[T] + } + + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) + + def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f581810c26..619514e8aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -93,6 +93,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { EliminateSerialization) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: + Batch("Typed Filter Optimization", FixedPoint(100), + EmbedSerializerInFilter) :: Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Batch("Subquery", Once, @@ -147,12 +149,18 @@ object EliminateSerialization extends Rule[LogicalPlan] { child = childWithoutSerialization) case m @ MapElements(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => val childWithoutSerialization = child.withObjectOutput m.copy( deserializer = childWithoutSerialization.output.head, child = childWithoutSerialization) + + case d @ DeserializeToObject(_, s: SerializeFromObject) + if d.outputObjectType == s.inputObjectType => + // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. 
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) + Project(objAttr :: Nil, s.child) } } @@ -1329,3 +1337,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } } + +/** + * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a + * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed + * the deserializer in filter condition to save the extra serialization at last. + */ +object EmbedSerializerInFilter extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) => + val numObjects = condition.collect { + case a: Attribute if a == d.output.head => a + }.length + + if (numObjects > 1) { + // If the filter condition references the object more than one times, we should not embed + // deserializer in it as the deserialization will happen many times and slow down the + // execution. + // TODO: we can still embed it if we can make sure subexpression elimination works here. + s + } else { + val newCondition = condition transform { + case a: Attribute if a == d.output.head => d.deserializer.child + } + Filter(newCondition, d.child) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index ec33a538a9..6df46189b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -21,7 +21,42 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, ObjectType, StructType} + +object CatalystSerde { + def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) + DeserializeToObject(Alias(deserializer, "obj")(), child) + } + + def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { + SerializeFromObject(encoderFor[T].namedExpressions, child) + } +} + +/** + * Takes the input row from child and turns it into object using the given deserializer expression. + * The output of this operator is a single-field safe row containing the deserialized object. + */ +case class DeserializeToObject( + deserializer: Alias, + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + def outputObjectType: DataType = deserializer.dataType +} + +/** + * Takes the input object from child and turns in into unsafe row using the given serializer + * expression. The output of its child must be a single-field row containing the input object. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + def inputObjectType: DataType = child.output.head.dataType +} /** * A trait for logical operators that apply user defined functions to domain objects. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala new file mode 100644 index 0000000000..1fae64e3bc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.BooleanType + +class TypedFilterOptimizationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("EliminateSerialization", FixedPoint(50), + EliminateSerialization) :: + Batch("EmbedSerializerInFilter", FixedPoint(50), + EmbedSerializerInFilter) :: Nil + } + + implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + + test("back to back filter") { + val input = LocalRelation('_1.int, '_2.int) + val f1 = (i: (Int, Int)) => i._1 > 0 + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = input.filter(f1).filter(f2).analyze + + val optimized = Optimize.execute(query) + + val expected = input.deserialize[(Int, Int)] + .where(callFunction(f1, BooleanType, 'obj)) + .select('obj.as("obj")) + .where(callFunction(f2, BooleanType, 'obj)) + .serialize[(Int, Int)].analyze + + comparePlans(optimized, expected) + } + + test("embed deserializer in filter condition if there is only one filter") { + val input = LocalRelation('_1.int, '_2.int) + val f = (i: (Int, Int)) => i._1 > 0 + + val query = input.filter(f).analyze + + val optimized = Optimize.execute(query) + + val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer) + val condition = callFunction(f, BooleanType, deserializer) + val expected = input.where(condition).analyze + + comparePlans(optimized, expected) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2854d5f9da..2f6d8d109f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1879,7 +1879,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ 
@Experimental - def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) + def filter(func: T => Boolean): Dataset[T] = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val function = Literal.create(func, ObjectType(classOf[T => Boolean])) + val condition = Invoke(function, "apply", BooleanType, deserialized.output) + val filter = Filter(condition, deserialized) + withTypedPlan(CatalystSerde.serialize[T](filter)) + } /** * :: Experimental :: @@ -1890,7 +1896,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) + def filter(func: FilterFunction[T]): Dataset[T] = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]])) + val condition = Invoke(function, "call", BooleanType, deserialized.output) + val filter = Filter(condition, deserialized) + withTypedPlan(CatalystSerde.serialize[T](filter)) + } /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index eee2b946e3..c15aaed365 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -346,6 +346,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.DeserializeToObject(deserializer, child) => + execution.DeserializeToObject(deserializer, planLater(child)) :: Nil + case logical.SerializeFromObject(serializer, child) => + execution.SerializeFromObject(serializer, planLater(child)) :: Nil case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil case logical.MapElements(f, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index f48f3f09c7..d2ab18ef0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -27,6 +27,73 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.ObjectType +/** + * Takes the input row from child and turns it into object using the given deserializer expression. + * The output of this operator is a single-field safe row containing the deserialized object. 
+ */ +case class DeserializeToObject( + deserializer: Alias, + child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = ExpressionCanonicalizer.execute( + BindReferences.bindReference(deserializer, child.output)) + ctx.currentVars = input + val resultVars = bound.gen(ctx) :: Nil + consume(ctx, resultVars) + } + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output) + iter.map(projection) + } + } +} + +/** + * Takes the input object from child and turns in into unsafe row using the given serializer + * expression. The output of its child must be a single-field row containing the input object. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = serializer.map { expr => + ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output)) + } + ctx.currentVars = input + val resultVars = bound.map(_.gen(ctx)) + consume(ctx, resultVars) + } + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val projection = UnsafeProjection.create(serializer) + iter.map(projection) + } + } +} + /** * Helper functions for physical operators that work with user defined objects. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 6eb952445f..5f3dd906fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -28,16 +28,10 @@ object DatasetBenchmark { case class Data(l: Long, s: String) - def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "Dataset benchmark") - val sqlContext = new SQLContext(sparkContext) - + def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { import sqlContext.implicits._ - val numRows = 10000000 val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) - val numChains = 10 - val benchmark = new Benchmark("back-to-back map", numRows) val func = (d: Data) => Data(d.l + 1, d.s) @@ -61,7 +55,7 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -72,6 +66,63 @@ object DatasetBenchmark { res.foreach(_ => Unit) } + benchmark + } + + def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("back-to-back filter", numRows) + + val func = (d: Data, i: Int) => d.l % (100L + i) == 0L + val funcs = 0.until(numChains).map { i => + (d: Data) => func(d, i) + } + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] + var i = 0 + while (i < numChains) { + res = res.filter(funcs(i)) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.filter($"l" % (100L + i) === 0L) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = rdd.filter(funcs(i)) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark + } + + def main(args: Array[String]): Unit = { + val sparkContext = new SparkContext("local[*]", "Dataset benchmark") + val sqlContext = new SQLContext(sparkContext) + + val numRows = 10000000 + val numChains = 10 + + val benchmark = backToBackMap(sqlContext, numRows, numChains) + val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) + /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz @@ -82,5 +133,14 @@ object DatasetBenchmark { RDD 216 / 237 46.3 21.6 4.2X */ benchmark.run() + + /* + back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Dataset 585 / 628 17.1 58.5 1.0X + DataFrame 62 / 80 160.7 6.2 9.4X + RDD 205 / 220 48.7 20.5 2.8X + */ + benchmark2.run() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 48a077d0e5..826862835a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan +import org.apache.spark.sql.types.ObjectType abstract class QueryTest extends PlanTest { @@ -204,6 +205,7 @@ abstract class QueryTest extends PlanTest { case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return + case Literal(_, _: ObjectType) => return } // bypass hive tests before we fix all corner cases in hive module. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f73ca887f1..4474cfcf6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.api.java.function.MapFunction -import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions.{avg, broadcast, col, max} @@ -82,4 +81,22 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) assert(ds.collect() === 0.until(10).map(_.toString).toArray) } + + test("typed filter should be included in WholeStageCodegen") { + val ds = sqlContext.range(10).filter(_ % 2 == 0) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined) + assert(ds.collect() === Array(0, 2, 4, 6, 8)) + } + + test("back-to-back typed filter should be included in WholeStageCodegen") { + val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined) + assert(ds.collect() === Array(0, 6)) + } } -- cgit v1.2.3 From 30e980ad8e6443dddd54f3c2d48b3904499545cf Mon Sep 17 00:00:00 2001 From: Michael Gummelt Date: Thu, 7 Apr 2016 17:41:37 -0700 Subject: [DOCS][MINOR] Remove sentence about Mesos not supporting cluster mode. Docs change to remove the sentence about Mesos not supporting cluster mode. It was not. Author: Michael Gummelt Closes #12249 from mgummelt/fix-mesos-cluster-docs. --- docs/submitting-applications.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md index 66025ed6ba..100ff0b147 100644 --- a/docs/submitting-applications.md +++ b/docs/submitting-applications.md @@ -58,8 +58,7 @@ for applications that involve the REPL (e.g. Spark shell). Alternatively, if your application is submitted from a machine far from the worker machines (e.g. locally on your laptop), it is common to use `cluster` mode to minimize network latency between -the drivers and the executors. Note that `cluster` mode is currently not supported for -Mesos clusters. Currently only YARN supports cluster mode for Python applications. 
+the drivers and the executors. Currently only YARN supports cluster mode for Python applications. For Python applications, simply pass a `.py` file in the place of `` instead of a JAR, and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`. -- cgit v1.2.3 From 3e29e372ff518827bae9dcd26087946fde476843 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 7 Apr 2016 17:49:39 -0700 Subject: [SPARK-14468] Always enable OutputCommitCoordinator ## What changes were proposed in this pull request? `OutputCommitCoordinator` was introduced to deal with concurrent task attempts racing to write output, leading to data loss or corruption. For more detail, read the [JIRA description](https://issues.apache.org/jira/browse/SPARK-14468). Before: `OutputCommitCoordinator` is enabled only if speculation is enabled. After: `OutputCommitCoordinator` is always enabled. Users may still disable this through `spark.hadoop.outputCommitCoordination.enabled`, but they really shouldn't... ## How was this patch tested? `OutputCommitCoordinator*Suite` Author: Andrew Or Closes #12244 from andrewor14/always-occ. --- .../org/apache/spark/mapred/SparkHadoopMapRedUtil.scala | 16 ++++++---------- .../OutputCommitCoordinatorIntegrationSuite.scala | 2 +- .../spark/scheduler/OutputCommitCoordinatorSuite.scala | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 891facba33..607283a306 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -33,11 +33,8 @@ object SparkHadoopMapRedUtil extends Logging { * the driver in order to determine whether this attempt can commit (please see SPARK-4879 for * details). * - * Output commit coordinator is only contacted when the following two configurations are both set - * to `true`: - * - * - `spark.speculation` - * - `spark.hadoop.outputCommitCoordination.enabled` + * Output commit coordinator is only used when `spark.hadoop.outputCommitCoordination.enabled` + * is set to true (which is the default). */ def commitTask( committer: MapReduceOutputCommitter, @@ -64,11 +61,10 @@ object SparkHadoopMapRedUtil extends Logging { if (committer.needsTaskCommit(mrTaskContext)) { val shouldCoordinateWithDriver: Boolean = { val sparkConf = SparkEnv.get.conf - // We only need to coordinate with the driver if there are multiple concurrent task - // attempts, which should only occur if speculation is enabled - val speculationEnabled = sparkConf.getBoolean("spark.speculation", defaultValue = false) - // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs - sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled) + // We only need to coordinate with the driver if there are concurrent task attempts. + // Note that this could happen even when speculation is not enabled (e.g. see SPARK-8029). + // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs. 
+ sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", defaultValue = true) } if (shouldCoordinateWithDriver) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index 9f41aca8a1..601f1c378c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -38,7 +38,7 @@ class OutputCommitCoordinatorIntegrationSuite super.beforeAll() val conf = new SparkConf() .set("master", "local[2,4]") - .set("spark.speculation", "true") + .set("spark.hadoop.outputCommitCoordination.enabled", "true") .set("spark.hadoop.mapred.output.committer.class", classOf[ThrowExceptionOnFirstAttemptOutputCommitter].getCanonicalName) sc = new SparkContext("local[2, 4]", "test", conf) diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index c461da65bd..8e509de767 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -77,7 +77,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SparkConf() .setMaster("local[4]") .setAppName(classOf[OutputCommitCoordinatorSuite].getSimpleName) - .set("spark.speculation", "true") + .set("spark.hadoop.outputCommitCoordination.enabled", "true") sc = new SparkContext(conf) { override private[spark] def createSparkEnv( conf: SparkConf, -- cgit v1.2.3 From 692c74840bc53debbb842db5372702f58207412c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Thu, 7 Apr 2016 18:05:54 -0700 Subject: [SPARK-14449][SQL] SparkContext should use SparkListenerInterface Currently all `SparkFirehoseListener` implementations are broken since we expect listeners to extend `SparkListener`, while the fire hose only extends `SparkListenerInterface`. This changes the addListener function and the config based injection to use the interface instead. The existing tests in SparkListenerSuite are improved such that they would have caught this. Follow-up to #12142 Author: Michael Armbrust Closes #12227 from marmbrus/fixListener. --- .../main/scala/org/apache/spark/SparkContext.scala | 8 +++++--- .../org/apache/spark/scheduler/SparkListenerBus.scala | 7 +++++-- .../apache/spark/scheduler/SparkListenerSuite.scala | 19 ++++++++++++++++--- project/MimaExcludes.scala | 1 + 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c40fada64b..9ec5cedf25 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1356,7 +1356,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * Register a listener to receive up-calls from events that happen during execution. 
*/ @DeveloperApi - def addSparkListener(listener: SparkListener) { + def addSparkListener(listener: SparkListenerInterface) { listenerBus.addListener(listener) } @@ -2007,7 +2007,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Use reflection to find the right constructor val constructors = { val listenerClass = Utils.classForName(className) - listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]] + listenerClass + .getConstructors + .asInstanceOf[Array[Constructor[_ <: SparkListenerInterface]]] } val constructorTakingSparkConf = constructors.find { c => c.getParameterTypes.sameElements(Array(classOf[SparkConf])) @@ -2015,7 +2017,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli lazy val zeroArgumentConstructor = constructors.find { c => c.getParameterTypes.isEmpty } - val listener: SparkListener = { + val listener: SparkListenerInterface = { if (constructorTakingSparkConf.isDefined) { constructorTakingSparkConf.get.newInstance(conf) } else if (zeroArgumentConstructor.isDefined) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 94f0574f0e..471586ac08 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -22,9 +22,12 @@ import org.apache.spark.util.ListenerBus /** * A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners */ -private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] { +private[spark] trait SparkListenerBus + extends ListenerBus[SparkListenerInterface, SparkListenerEvent] { - protected override def doPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = { + protected override def doPostEvent( + listener: SparkListenerInterface, + event: SparkListenerEvent): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => listener.onStageSubmitted(stageSubmitted) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 58d217ffef..b854d742b5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._ import org.scalatest.Matchers -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{ResetSystemProperties, RpcUtils} @@ -377,13 +377,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match } test("registering listeners via spark.extraListeners") { + val listeners = Seq( + classOf[ListenerThatAcceptsSparkConf], + classOf[FirehoseListenerThatAcceptsSparkConf], + classOf[BasicJobCounter]) val conf = new SparkConf().setMaster("local").setAppName("test") - .set("spark.extraListeners", classOf[ListenerThatAcceptsSparkConf].getName + "," + - classOf[BasicJobCounter].getName) + .set("spark.extraListeners", listeners.map(_.getName).mkString(",")) sc = new SparkContext(conf) sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) sc.listenerBus.listeners.asScala .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) 
should be (1) + sc.listenerBus.listeners.asScala + .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } /** @@ -476,3 +481,11 @@ private class ListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkListene var count = 0 override def onJobEnd(job: SparkListenerJobEnd): Unit = count += 1 } + +private class FirehoseListenerThatAcceptsSparkConf(conf: SparkConf) extends SparkFirehoseListener { + var count = 0 + override def onEvent(event: SparkListenerEvent): Unit = event match { + case job: SparkListenerJobEnd => count += 1 + case _ => + } +} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d916c49a6a..fbadc563b8 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -68,6 +68,7 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory"), // SPARK-14358 SparkListener from trait to abstract class + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.addSparkListener"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.JavaSparkListener"), ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkFirehoseListener"), ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.scheduler.SparkListener"), -- cgit v1.2.3 From 953ff897e422570a329d0aec98d573d3fb66ab9a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 7 Apr 2016 19:48:33 -0700 Subject: [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer ## What changes were proposed in this pull request? The EMLDAOptimizer should generally not delete its last checkpoint since that can cause failures when DistributedLDAModel methods are called (if any partitions need to be recovered from the checkpoint). This PR adds a "deleteLastCheckpoint" option which defaults to false. This is a change in behavior from Spark 1.6, in that the last checkpoint will not be removed by default. This involves adding the deleteLastCheckpoint option to both spark.ml and spark.mllib, and modifying PeriodicCheckpointer to support the option. This also: * Makes MLlibTestSparkContext extend TempDirectory and set the checkpointDir to tempDir * Updates LibSVMRelationSuite because of a name conflict with "tempDir" (and fixes a bug where it failed to delete a temp directory) * Adds a MIMA exclude for DistributedLDAModel constructor, which is already ```private[clustering]``` ## How was this patch tested? Added 2 new unit tests to spark.ml LDASuite, which calls into spark.mllib. Author: Joseph K. Bradley Closes #12166 from jkbradley/emlda-save-checkpoint. 
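A minimal usage sketch (not part of this patch) of the new expert param and the checkpoint-file accessors. It assumes an existing `SparkContext` with a checkpoint directory, an `SQLContext`, and a DataFrame `docs` with a "features" vector column; the checkpoint path and parameter values are illustrative:

```scala
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}

// Checkpointing requires a checkpoint dir; the path here is hypothetical.
sc.setCheckpointDir("/tmp/lda-checkpoints")

val lda = new LDA()
  .setK(10)
  .setMaxIter(20)
  .setOptimizer("em")
  .setCheckpointInterval(2)
  .setKeepLastCheckpoint(true)  // new expert param; true is the default

val model = lda.fit(docs).asInstanceOf[DistributedLDAModel]

// The last checkpoint is kept, so model methods that need to recover a lost
// partition can still do so. The remaining files can be inspected...
model.getCheckpointFiles.foreach(println)

// ...and removed explicitly once the model and derived data are out of scope.
model.deleteCheckpointFiles()
```

Keeping the last checkpoint by default trades a small amount of disk space for robustness of later `DistributedLDAModel` calls; cleanup stays an explicit, opt-in step.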
--- .../scala/org/apache/spark/ml/clustering/LDA.scala | 79 ++++++++++++++++++++-- .../apache/spark/mllib/clustering/LDAModel.scala | 13 ++-- .../spark/mllib/clustering/LDAOptimizer.scala | 34 ++++++++-- .../spark/mllib/impl/PeriodicCheckpointer.scala | 41 ++++++++--- .../org/apache/spark/ml/clustering/LDASuite.scala | 28 ++++++++ .../ml/source/libsvm/LibSVMRelationSuite.scala | 15 ++-- .../spark/mllib/util/MLlibTestSparkContext.scala | 13 +++- project/MimaExcludes.scala | 3 + 8 files changed, 194 insertions(+), 32 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 60cc345565..727b724708 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -17,18 +17,19 @@ package org.apache.spark.ml.clustering -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, - EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, - LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, - OnlineLDAOptimizer => OldOnlineLDAOptimizer} + EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, + LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, + OnlineLDAOptimizer => OldOnlineLDAOptimizer} +import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SQLContext} @@ -41,6 +42,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM /** * Param for the number of topics (clusters) to infer. Must be > 1. Default: 10. + * * @group param */ @Since("1.6.0") @@ -173,6 +175,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * This uses a variational approximation following Hoffman et al. (2010), where the approximate * distribution is called "gamma." Technically, this method returns this approximation "gamma" * for each document. + * * @group param */ @Since("1.6.0") @@ -191,6 +194,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * iterations count less. * This is called "tau0" in the Online LDA paper (Hoffman et al., 2010) * Default: 1024, following Hoffman et al. + * * @group expertParam */ @Since("1.6.0") @@ -207,6 +211,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * This should be between (0.5, 1.0] to guarantee asymptotic convergence. * This is called "kappa" in the Online LDA paper (Hoffman et al., 2010). * Default: 0.51, based on Hoffman et al. + * * @group expertParam */ @Since("1.6.0") @@ -230,6 +235,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * [[org.apache.spark.mllib.clustering.OnlineLDAOptimizer]]. * * Default: 0.05, i.e., 5% of total documents. 
+ * * @group param */ @Since("1.6.0") @@ -246,6 +252,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * document-topic distribution) will be optimized during training. * Setting this to true will make the model more expressive and fit the training data better. * Default: false + * * @group expertParam */ @Since("1.6.0") @@ -257,8 +264,32 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM @Since("1.6.0") def getOptimizeDocConcentration: Boolean = $(optimizeDocConcentration) + /** + * For EM optimizer, if using checkpointing, this indicates whether to keep the last + * checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can + * cause failures if a data partition is lost, so set this bit with care. + * Note that checkpoints will be cleaned up via reference counting, regardless. + * + * See [[DistributedLDAModel.getCheckpointFiles]] for getting remaining checkpoints and + * [[DistributedLDAModel.deleteCheckpointFiles]] for removing remaining checkpoints. + * + * Default: true + * + * @group expertParam + */ + @Since("2.0.0") + final val keepLastCheckpoint = new BooleanParam(this, "keepLastCheckpoint", + "For EM optimizer, if using checkpointing, this indicates whether to keep the last" + + " checkpoint. If false, then the checkpoint will be deleted. Deleting the checkpoint can" + + " cause failures if a data partition is lost, so set this bit with care.") + + /** @group expertGetParam */ + @Since("2.0.0") + def getKeepLastCheckpoint: Boolean = $(keepLastCheckpoint) + /** * Validates and transforms the input schema. + * * @param schema input schema * @return output schema */ @@ -303,6 +334,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM .setOptimizeDocConcentration($(optimizeDocConcentration)) case "em" => new OldEMLDAOptimizer() + .setKeepLastCheckpoint($(keepLastCheckpoint)) } } @@ -341,6 +373,7 @@ sealed abstract class LDAModel private[ml] ( /** * The features for LDA should be a [[Vector]] representing the word counts in a document. * The vector should be of length vocabSize, with counts for each term (word). + * * @group setParam */ @Since("1.6.0") @@ -619,6 +652,35 @@ class DistributedLDAModel private[ml] ( @Since("1.6.0") lazy val logPrior: Double = oldDistributedModel.logPrior + private var _checkpointFiles: Array[String] = oldDistributedModel.checkpointFiles + + /** + * If using checkpointing and [[LDA.keepLastCheckpoint]] is set to true, then there may be + * saved checkpoint files. This method is provided so that users can manage those files. + * + * Note that removing the checkpoints can cause failures if a partition is lost and is needed + * by certain [[DistributedLDAModel]] methods. Reference counting will clean up the checkpoints + * when this model and derivative data go out of scope. + * + * @return Checkpoint files from training + */ + @DeveloperApi + @Since("2.0.0") + def getCheckpointFiles: Array[String] = _checkpointFiles + + /** + * Remove any remaining checkpoint files from training. 
+ * + * @see [[getCheckpointFiles]] + */ + @DeveloperApi + @Since("2.0.0") + def deleteCheckpointFiles(): Unit = { + val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + _checkpointFiles.foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs)) + _checkpointFiles = Array.empty[String] + } + @Since("1.6.0") override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this) } @@ -696,11 +758,12 @@ class LDA @Since("1.6.0") ( setDefault(maxIter -> 20, k -> 10, optimizer -> "online", checkpointInterval -> 10, learningOffset -> 1024, learningDecay -> 0.51, subsamplingRate -> 0.05, - optimizeDocConcentration -> true) + optimizeDocConcentration -> true, keepLastCheckpoint -> true) /** * The features for LDA should be a [[Vector]] representing the word counts in a document. * The vector should be of length vocabSize, with counts for each term (word). + * * @group setParam */ @Since("1.6.0") @@ -758,6 +821,10 @@ class LDA @Since("1.6.0") ( @Since("1.6.0") def setOptimizeDocConcentration(value: Boolean): this.type = set(optimizeDocConcentration, value) + /** @group expertSetParam */ + @Since("2.0.0") + def setKeepLastCheckpoint(value: Boolean): this.type = set(keepLastCheckpoint, value) + @Since("1.6.0") override def copy(extra: ParamMap): LDA = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 25d67a3756..27b4004927 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -534,7 +534,8 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") override val docConcentration: Vector, @Since("1.5.0") override val topicConcentration: Double, private[spark] val iterationTimes: Array[Double], - override protected[clustering] val gammaShape: Double = 100) + override protected[clustering] val gammaShape: Double = DistributedLDAModel.defaultGammaShape, + private[spark] val checkpointFiles: Array[String] = Array.empty[String]) extends LDAModel { import LDA._ @@ -806,11 +807,9 @@ class DistributedLDAModel private[clustering] ( override protected def formatVersion = "1.0" - /** - * Java-friendly version of [[topicDistributions]] - */ @Since("1.5.0") override def save(sc: SparkContext, path: String): Unit = { + // Note: This intentionally does not save checkpointFiles. DistributedLDAModel.SaveLoadV1_0.save( sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, iterationTimes, gammaShape) @@ -822,6 +821,12 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") object DistributedLDAModel extends Loader[DistributedLDAModel] { + /** + * The [[DistributedLDAModel]] constructor's default arguments assume gammaShape = 100 + * to ensure equivalence in LDAModel.toLocal conversion. 
+ */ + private[clustering] val defaultGammaShape: Double = 100 + private object SaveLoadV1_0 { val thisFormatVersion = "1.0" diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 2b404a8651..6418f0d3b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -80,9 +80,29 @@ final class EMLDAOptimizer extends LDAOptimizer { import LDA._ + // Adjustable parameters + private var keepLastCheckpoint: Boolean = true + /** - * The following fields will only be initialized through the initialize() method + * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up). + */ + @Since("2.0.0") + def getKeepLastCheckpoint: Boolean = this.keepLastCheckpoint + + /** + * If using checkpointing, this indicates whether to keep the last checkpoint (vs clean up). + * Deleting the checkpoint can cause failures if a data partition is lost, so set this bit with + * care. Note that checkpoints will be cleaned up via reference counting, regardless. + * + * Default: true */ + @Since("2.0.0") + def setKeepLastCheckpoint(keepLastCheckpoint: Boolean): this.type = { + this.keepLastCheckpoint = keepLastCheckpoint + this + } + + // The following fields will only be initialized through the initialize() method private[clustering] var graph: Graph[TopicCounts, TokenCount] = null private[clustering] var k: Int = 0 private[clustering] var vocabSize: Int = 0 @@ -208,12 +228,18 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") - this.graphCheckpointer.deleteAllCheckpoints() + val checkpointFiles: Array[String] = if (keepLastCheckpoint) { + this.graphCheckpointer.deleteAllCheckpointsButLast() + this.graphCheckpointer.getAllCheckpointFiles + } else { + this.graphCheckpointer.deleteAllCheckpoints() + Array.empty[String] + } // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in - // LDAModel.toLocal conversion + // LDAModel.toLocal conversion. new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, Vectors.dense(Array.fill(this.k)(this.docConcentration)), this.topicConcentration, - iterationTimes) + iterationTimes, DistributedLDAModel.defaultGammaShape, checkpointFiles) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index 391f89aa14..cbc8f60112 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -133,6 +133,24 @@ private[mllib] abstract class PeriodicCheckpointer[T]( } } + /** + * Call this at the end to delete any remaining checkpoint files, except for the last checkpoint. + * Note that there may not be any checkpoints at all. + */ + def deleteAllCheckpointsButLast(): Unit = { + while (checkpointQueue.size > 1) { + removeCheckpointFile() + } + } + + /** + * Get all current checkpoint files. + * This is useful in combination with [[deleteAllCheckpointsButLast()]]. 
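+   * Returns an empty array if no checkpoints have been written (for example, when no
+   * checkpoint directory is set or checkpointing is disabled).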
+ */ + def getAllCheckpointFiles: Array[String] = { + checkpointQueue.flatMap(getCheckpointFiles).toArray + } + /** * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files. * This prints a warning but does not fail if the files cannot be removed. @@ -141,15 +159,20 @@ private[mllib] abstract class PeriodicCheckpointer[T]( val old = checkpointQueue.dequeue() // Since the old checkpoint is not deleted by Spark, we manually delete it. val fs = FileSystem.get(sc.hadoopConfiguration) - getCheckpointFiles(old).foreach { checkpointFile => - try { - fs.delete(new Path(checkpointFile), true) - } catch { - case e: Exception => - logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + - checkpointFile) - } - } + getCheckpointFiles(old).foreach(PeriodicCheckpointer.removeCheckpointFile(_, fs)) } +} + +private[spark] object PeriodicCheckpointer extends Logging { + /** Delete a checkpoint file, and log a warning if deletion fails. */ + def removeCheckpointFile(path: String, fs: FileSystem): Unit = { + try { + fs.delete(new Path(path), true) + } catch { + case e: Exception => + logWarning("PeriodicCheckpointer could not remove old checkpoint file: " + + path) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index dd3f4c6e53..a1c93891c7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.clustering +import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -261,4 +263,30 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) } + + test("EM LDA checkpointing: save last checkpoint") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + // There should be 1 checkpoint remaining. 
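+    // With keepLastCheckpoint left at its default (true), the EM optimizer calls
+    // deleteAllCheckpointsButLast(), so only the most recent checkpoint survives training.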
+ assert(model.getCheckpointFiles.length === 1) + val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + assert(fs.exists(new Path(model.getCheckpointFiles.head))) + model.deleteCheckpointFiles() + assert(model.getCheckpointFiles.isEmpty) + } + + test("EM LDA checkpointing: remove last checkpoint") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3).setCheckpointInterval(1) + .setKeepLastCheckpoint(false) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 114a238462..0bd14978b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -28,8 +28,9 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.util.Utils + class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { - var tempDir: File = _ + // Path for dataset var path: String = _ override def beforeAll(): Unit = { @@ -40,15 +41,15 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - tempDir = Utils.createTempDir() - val file = new File(tempDir, "part-00000") + val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data") + val file = new File(dir, "part-00000") Files.write(lines, file, StandardCharsets.UTF_8) - path = tempDir.toURI.toString + path = dir.toURI.toString } override def afterAll(): Unit = { try { - Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(new File(path)) } finally { super.afterAll() } @@ -86,7 +87,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("write libsvm data and read it again") { val df = sqlContext.read.format("libsvm").load(path) - val tempDir2 = Utils.createTempDir() + val tempDir2 = new File(tempDir, "read_write_test") val writepath = tempDir2.toURI.toString // TODO: Remove requirement to coalesce by supporting multiple reads. 
df.coalesce(1).write.format("libsvm").mode(SaveMode.Overwrite).save(writepath) @@ -99,7 +100,7 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { test("write libsvm data failed due to invalid schema") { val df = sqlContext.read.format("text").load(path) - val e = intercept[SparkException] { + intercept[SparkException] { df.write.format("libsvm").save(path + "_2") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index ebcd591465..cb1efd5251 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -17,14 +17,20 @@ package org.apache.spark.mllib.util -import org.scalatest.{BeforeAndAfterAll, Suite} +import java.io.File + +import org.apache.hadoop.fs.Path +import org.scalatest.Suite import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.util.TempDirectory import org.apache.spark.sql.SQLContext +import org.apache.spark.util.Utils -trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => +trait MLlibTestSparkContext extends TempDirectory { self: Suite => @transient var sc: SparkContext = _ @transient var sqlContext: SQLContext = _ + @transient var checkpointDir: String = _ override def beforeAll() { super.beforeAll() @@ -35,10 +41,13 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => SQLContext.clearActive() sqlContext = new SQLContext(sc) SQLContext.setActive(sqlContext) + checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString + sc.setCheckpointDir(checkpointDir) } override def afterAll() { try { + Utils.deleteRecursively(new File(checkpointDir)) sqlContext = null SQLContext.clearActive() if (sc != null) { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index fbadc563b8..a53161dc9a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -614,6 +614,9 @@ object MimaExcludes { ) ++ Seq( // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") + ) ++ Seq( + // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") ) case v if v.startsWith("1.6") => Seq( -- cgit v1.2.3 From 04fb7dba704afa4e20eb8c72d6568f7f55694157 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 7 Apr 2016 21:41:41 -0700 Subject: Replace getLocalizedMessage with just normal toString in exception handling in WriterContainer. --- .../org/apache/spark/sql/execution/datasources/WriterContainer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index f6b7f0854b..d2bbf196cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -136,7 +136,7 @@ private[sql] abstract class BaseWriterContainer( // to tell the user to look for the actual error. 
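+        // Using the exception itself here (Throwable.toString) also surfaces the exception
+        // class name, which getLocalizedMessage alone does not include.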
throw new SparkException("The output file already exists but this could be due to a " + "failure from an earlier attempt. Look through the earlier logs or stage page for " + - "the first error.\n File exists error: " + e.getLocalizedMessage, e) + "the first error.\n File exists error: " + e, e) } else { throw e } -- cgit v1.2.3 From 725b860e2b7b675d95b10c46f2b329c30cd21faf Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 8 Apr 2016 00:28:59 -0700 Subject: [SPARK-14103][SQL] Parse unescaped quotes in CSV data source. ## What changes were proposed in this pull request? This PR resolves the problem during parsing unescaped quotes in input data. For example, currently the data below: ``` "a"b,ccc,ddd e,f,g ``` produces a data below: - **Before** ```bash ["a"b,ccc,ddd[\n]e,f,g] <- as a value. ``` - **After** ```bash ["a"b], [ccc], [ddd] [e], [f], [g] ``` This PR bumps up the Univocity parser's version. This was fixed in `2.0.2`, https://github.com/uniVocity/univocity-parsers/issues/60. ## How was this patch tested? Unit tests in `CSVSuite` and `sbt/sbt scalastyle`. Author: hyukjinkwon Closes #12226 from HyukjinKwon/SPARK-14103-quote. --- dev/deps/spark-deps-hadoop-2.2 | 2 +- dev/deps/spark-deps-hadoop-2.3 | 2 +- dev/deps/spark-deps-hadoop-2.4 | 2 +- dev/deps/spark-deps-hadoop-2.6 | 2 +- dev/deps/spark-deps-hadoop-2.7 | 2 +- sql/core/pom.xml | 2 +- .../spark/sql/execution/datasources/csv/CSVParser.scala | 1 + sql/core/src/test/resources/unescaped-quotes.csv | 2 ++ .../spark/sql/execution/datasources/csv/CSVSuite.scala | 12 ++++++++++++ 9 files changed, 21 insertions(+), 6 deletions(-) create mode 100644 sql/core/src/test/resources/unescaped-quotes.csv diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 2c24366cc3..2794b3d235 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -175,7 +175,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-1.5.6.jar +univocity-parsers-2.0.2.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index e9cb0d8f3e..4906fe9cfa 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -166,7 +166,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-1.5.6.jar +univocity-parsers-2.0.2.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index d8d1840da5..23ff5cfa2e 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -167,7 +167,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-1.5.6.jar +univocity-parsers-2.0.2.jar xbean-asm5-shaded-4.4.jar xmlenc-0.52.jar xz-1.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 8beede1e38..9b5a5643f3 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -173,7 +173,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar super-csv-2.2.0.jar -univocity-parsers-1.5.6.jar +univocity-parsers-2.0.2.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a9d814f944..1dca2fc55a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -174,7 +174,7 @@ stax-api-1.0.1.jar stream-2.7.0.jar stringtemplate-3.2.1.jar 
super-csv-2.2.0.jar -univocity-parsers-1.5.6.jar +univocity-parsers-2.0.2.jar xbean-asm5-shaded-4.4.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 708670b292..8b1017042c 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -39,7 +39,7 @@ com.univocity univocity-parsers - 1.5.6 + 2.0.2 jar diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala index 5570b2c173..c3d863f547 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala @@ -47,6 +47,7 @@ private[sql] abstract class CsvReader(params: CSVOptions, headers: Seq[String]) settings.setMaxColumns(params.maxColumns) settings.setNullValue(params.nullValue) settings.setMaxCharsPerColumn(params.maxCharsPerColumn) + settings.setParseUnescapedQuotesUntilDelimiter(true) if (headers != null) settings.setHeaders(headers: _*) new CsvParser(settings) diff --git a/sql/core/src/test/resources/unescaped-quotes.csv b/sql/core/src/test/resources/unescaped-quotes.csv new file mode 100644 index 0000000000..7c68055575 --- /dev/null +++ b/sql/core/src/test/resources/unescaped-quotes.csv @@ -0,0 +1,2 @@ +"a"b,ccc,ddd +ab,cc"c,ddd" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 58d9d69d9a..9baae80f15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -45,6 +45,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { private val disableCommentsFile = "disable_comments.csv" private val boolFile = "bool.csv" private val simpleSparseFile = "simple_sparse.csv" + private val unescapedQuotesFile = "unescaped-quotes.csv" private def testFile(fileName: String): String = { Thread.currentThread().getContextClassLoader.getResource(fileName).toString @@ -140,6 +141,17 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { verifyCars(cars, withHeader = true) } + test("parse unescaped quotes with maxCharsPerColumn") { + val rows = sqlContext.read + .format("csv") + .option("maxCharsPerColumn", "4") + .load(testFile(unescapedQuotesFile)) + + val expectedRows = Seq(Row("\"a\"b", "ccc", "ddd"), Row("ab", "cc\"c", "ddd\"")) + + checkAnswer(rows, expectedRows) + } + test("bad encoding name") { val exception = intercept[UnsupportedCharsetException] { sqlContext -- cgit v1.2.3 From 73b56a3c6c5c590219b42884c8bbe88b0a236987 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Fri, 8 Apr 2016 00:30:26 -0700 Subject: [SPARK-14189][SQL] JSON data sources find compatible types even if inferred decimal type is not capable of the others ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14189 When inferred types in the same field during finding compatible `DataType`, are `IntegralType` and `DecimalType` but `DecimalType` is not capable of the given `IntegralType`, JSON data source simply fails to find a compatible type resulting in `StringType`. This can be observed when `prefersDecimal` is enabled. 
```scala def mixedIntegerAndDoubleRecords: RDD[String] = sqlContext.sparkContext.parallelize( """{"a": 3, "b": 1.1}""" :: """{"a": 3.1, "b": 1}""" :: Nil) val jsonDF = sqlContext.read .option("prefersDecimal", "true") .json(mixedIntegerAndDoubleRecords) .printSchema() ``` - **Before** ``` root |-- a: string (nullable = true) |-- b: string (nullable = true) ``` - **After** ``` root |-- a: decimal(21, 1) (nullable = true) |-- b: decimal(21, 1) (nullable = true) ``` (Note that integer is inferred as `LongType` which becomes `DecimalType(20, 0)`) ## How was this patch tested? unit tests were used and style tests by `dev/run_tests`. Author: hyukjinkwon Closes #11993 from HyukjinKwon/SPARK-14189. --- .../execution/datasources/json/InferSchema.scala | 8 ++++++++ .../sql/execution/datasources/json/JsonSuite.scala | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 4a34f365e4..8e8238a594 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -256,6 +256,14 @@ private[sql] object InferSchema { case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) + // The case that given `DecimalType` is capable of given `IntegralType` is handled in + // `findTightestCommonTypeOfTwo`. Both cases below will be executed only when + // the given `DecimalType` is not capable of the given `IntegralType`. + case (t1: IntegralType, t2: DecimalType) => + compatibleType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + compatibleType(t1, DecimalType.forType(t2)) + // strings and every string is a Json object. case (_, _) => StringType } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 421862c394..2a18acb95b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -773,6 +773,28 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } + test("Find compatible types even if inferred DecimalType is not capable of other IntegralType") { + val mixedIntegerAndDoubleRecords = sparkContext.parallelize( + """{"a": 3, "b": 1.1}""" :: + s"""{"a": 3.1, "b": 0.${"0" * 38}1}""" :: Nil) + val jsonDF = sqlContext.read + .option("prefersDecimal", "true") + .json(mixedIntegerAndDoubleRecords) + + // The values in `a` field will be decimals as they fit in decimal. For `b` field, + // they will be doubles as `1.0E-39D` does not fit. 
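+    // (The long value 3 is inferred as DecimalType(20, 0); unifying it with 1.1 widens the
+    // field to DecimalType(21, 1).)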
+ val expectedSchema = StructType( + StructField("a", DecimalType(21, 1), true) :: + StructField("b", DoubleType, true) :: Nil) + + assert(expectedSchema === jsonDF.schema) + checkAnswer( + jsonDF, + Row(BigDecimal("3"), 1.1D) :: + Row(BigDecimal("3.1"), 1.0E-39D) :: Nil + ) + } + test("Infer big integers correctly even when it does not fit in decimal") { val jsonDF = sqlContext.read .json(bigIntegerRecords) -- cgit v1.2.3 From 6447098013fad708769423ef108a2b071e0930d8 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 8 Apr 2016 11:36:41 +0100 Subject: [SPARK-14402][HOTFIX] Fix ExpressionDescription annotation ## What changes were proposed in this pull request? Fix for the error introduced in https://github.com/apache/spark/commit/c59abad052b7beec4ef550049413e95578e545be: ``` /Users/jacek/dev/oss/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala:626: error: annotation argument needs to be a constant; found: "_FUNC_(str) - ".+("Returns str, with the first letter of each word in uppercase, all other letters in ").+("lowercase. Words are delimited by white space.") "Returns str, with the first letter of each word in uppercase, all other letters in " + ^ ``` ## How was this patch tested? Local build Author: Jacek Laskowski Closes #12192 from jaceklaskowski/SPARK-14402-HOTFIX. --- .../apache/spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b6ea03cd5c..7e0e7a833b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -622,9 +622,9 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC * Words are delimited by whitespace. */ @ExpressionDescription( - usage = "_FUNC_(str) - " + - "Returns str, with the first letter of each word in uppercase, all other letters in " + - "lowercase. Words are delimited by white space.", + usage = + """_FUNC_(str) - Returns str with the first letter of each word in uppercase. + All other letters are in lowercase. Words are delimited by white space.""", extended = "> SELECT initcap('sPark sql');\n 'Spark Sql'") case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { -- cgit v1.2.3 From 583b5e05309adb73cdffd974a810d6bfb5f2ff95 Mon Sep 17 00:00:00 2001 From: Aaron Tokhy Date: Fri, 8 Apr 2016 12:52:25 +0100 Subject: [SPARK-14470] Allow for overriding both httpclient and httpcore versions ## What changes were proposed in this pull request? This splits commons.httpclient.version from commons.httpcore.version, since these two versions do not necessarily have to be the same. This change may follow up with an up-to-date version of the httpclient/httpcore libraries. The latest 4.3.x httpclient version as of writing is 4.3.6 and the latest 4.3.x httpcore version as of writing is 4.3.3. This change would be a prerequisite for potentially moving to this new bugfix version. ## How was this patch tested? no version change was made for httpclient/httpcore versions mvn package Author: Aaron Tokhy Closes #12245 from atokhy/pull-request. 
--- pom.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 66a34e4bdf..1b40983a6c 100644 --- a/pom.xml +++ b/pom.xml @@ -151,6 +151,7 @@ 0.10.2 4.3.2 + 4.3.2 3.1 3.4.1 @@ -412,7 +413,7 @@ org.apache.httpcomponents httpcore - ${commons.httpclient.version} + ${commons.httpcore.version} org.seleniumhq.selenium -- cgit v1.2.3 From a9b630f42ac0c6be3437f206beddaf0ef737f5c8 Mon Sep 17 00:00:00 2001 From: Mark Grover Date: Fri, 8 Apr 2016 10:10:10 -0700 Subject: [SPARK-14477][BUILD] Allow custom mirrors for downloading artifacts in build/mvn ## What changes were proposed in this pull request? Allows to override locations for downloading Apache and Typesafe artifacts in build/mvn script. ## How was this patch tested? By running script like ```` # Remove all previously downloaded artifacts rm -rf build/apache-maven* rm -rf build/zinc-* rm -rf build/scala-* # Make sure path is clean and doesn't contain mvn, for example. ... # Run a command without setting anything and make sure it succeeds build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.6.0 package # Run a command setting the default location as mirror and make sure it succeeds APACHE_MIRROR=http://mirror.infra.cloudera.com/apache/ build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.6.0 package # Do the same without the trailing slash this time and make sure it succeeds APACHE_MIRROR=http://mirror.infra.cloudera.com/apache build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.6.0 package # Do it with a bad URL and make sure it fails APACHE_MIRROR=xyz build/mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.6.0 package ```` Author: Mark Grover Closes #12250 from markgrover/spark-14477. --- build/mvn | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/build/mvn b/build/mvn index 41c0850ccb..54b815d2b8 100755 --- a/build/mvn +++ b/build/mvn @@ -70,9 +70,10 @@ install_app() { # Install maven under the build/ folder install_mvn() { local MVN_VERSION="3.3.9" + local APACHE_MIRROR=${APACHE_MIRROR:-https://archive.apache.org/dist} install_app \ - "https://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "${APACHE_MIRROR}/maven/maven-3/${MVN_VERSION}/binaries" \ "apache-maven-${MVN_VERSION}-bin.tar.gz" \ "apache-maven-${MVN_VERSION}/bin/mvn" @@ -83,8 +84,10 @@ install_mvn() { install_zinc() { local zinc_path="zinc-0.3.9/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} + install_app \ - "https://downloads.typesafe.com/zinc/0.3.9" \ + "${TYPESAFE_MIRROR}/zinc/0.3.9" \ "zinc-0.3.9.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" @@ -98,9 +101,10 @@ install_scala() { local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | \ head -1 | cut -f2 -d'>' | cut -f1 -d'<'` local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" + local TYPESAFE_MIRROR=${TYPESAFE_MIRROR:-https://downloads.typesafe.com} install_app \ - "https://downloads.typesafe.com/scala/${scala_version}" \ + "${TYPESAFE_MIRROR}/scala/${scala_version}" \ "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" -- cgit v1.2.3 From e5d8d6e09cad304e353c96f9408fb9f799348827 Mon Sep 17 00:00:00 2001 From: Kai Jiang Date: Fri, 8 Apr 2016 10:39:12 -0700 Subject: [SPARK-14373][PYSPARK] PySpark RandomForestClassifier, Regressor support export/import ## What changes were proposed in this pull request? supporting `RandomForest{Classifier, Regressor}` save/load for Python API. 
[JIRA](https://issues.apache.org/jira/browse/SPARK-14373) ## How was this patch tested? doctest Author: Kai Jiang Closes #12238 from vectorijk/spark-14373. --- python/pyspark/ml/classification.py | 15 +++++++++++++-- python/pyspark/ml/regression.py | 15 +++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index be7f9ea9ef..d98919b3c6 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -621,7 +621,8 @@ class DecisionTreeClassificationModel(DecisionTreeModel, JavaMLWritable, JavaMLR @inherit_doc class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, HasRawPredictionCol, HasProbabilityCol, - RandomForestParams, TreeClassifierParams, HasCheckpointInterval): + RandomForestParams, TreeClassifierParams, HasCheckpointInterval, + JavaMLWritable, JavaMLReadable): """ `http://en.wikipedia.org/wiki/Random_forest Random Forest` learning algorithm for classification. @@ -655,6 +656,16 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 1.0 + >>> rfc_path = temp_path + "/rfc" + >>> rf.save(rfc_path) + >>> rf2 = RandomForestClassifier.load(rfc_path) + >>> rf2.getNumTrees() + 3 + >>> model_path = temp_path + "/rfc_model" + >>> model.save(model_path) + >>> model2 = RandomForestClassificationModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True .. versionadded:: 1.4.0 """ @@ -703,7 +714,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred return RandomForestClassificationModel(java_model) -class RandomForestClassificationModel(TreeEnsembleModels): +class RandomForestClassificationModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by RandomForestClassifier. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 6cd1b4bf3a..00a6a0de90 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -782,7 +782,8 @@ class DecisionTreeRegressionModel(DecisionTreeModel, JavaMLWritable, JavaMLReada @inherit_doc class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed, - RandomForestParams, TreeRegressorParams, HasCheckpointInterval): + RandomForestParams, TreeRegressorParams, HasCheckpointInterval, + JavaMLWritable, JavaMLReadable): """ `http://en.wikipedia.org/wiki/Random_forest Random Forest` learning algorithm for regression. @@ -805,6 +806,16 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"]) >>> model.transform(test1).head().prediction 0.5 + >>> rfr_path = temp_path + "/rfr" + >>> rf.save(rfr_path) + >>> rf2 = RandomForestRegressor.load(rfr_path) + >>> rf2.getNumTrees() + 2 + >>> model_path = temp_path + "/rfr_model" + >>> model.save(model_path) + >>> model2 = RandomForestRegressionModel.load(model_path) + >>> model.featureImportances == model2.featureImportances + True .. 
versionadded:: 1.4.0 """ @@ -854,7 +865,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi return RandomForestRegressionModel(java_model) -class RandomForestRegressionModel(TreeEnsembleModels): +class RandomForestRegressionModel(TreeEnsembleModels, JavaMLWritable, JavaMLReadable): """ Model fitted by RandomForestRegressor. -- cgit v1.2.3 From e0ad75f2b55772efc82a6f8ebb1b2d80fe27d9b5 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Fri, 8 Apr 2016 10:47:05 -0700 Subject: [SPARK-12569][PYSPARK][ML] DecisionTreeRegressor: provide variance of prediction: Python API ## What changes were proposed in this pull request? A new column VarianceCol has been added to DecisionTreeRegressor in ML scala code. This patch adds the corresponding Python API, HasVarianceCol, to class DecisionTreeRegressor. ## How was this patch tested? ./dev/lint-python PEP8 checks passed. rm -rf _build/* pydoc checks passed. ./python/run-tests --python-executables=python2.7 --modules=pyspark-ml Running PySpark tests. Output is in /Users/mwang/spark_ws_0904/python/unit-tests.log Will test against the following Python executables: ['python2.7'] Will test the following Python modules: ['pyspark-ml'] Finished test(python2.7): pyspark.ml.evaluation (12s) Finished test(python2.7): pyspark.ml.clustering (18s) Finished test(python2.7): pyspark.ml.classification (30s) Finished test(python2.7): pyspark.ml.recommendation (28s) Finished test(python2.7): pyspark.ml.feature (43s) Finished test(python2.7): pyspark.ml.regression (31s) Finished test(python2.7): pyspark.ml.tuning (19s) Finished test(python2.7): pyspark.ml.tests (34s) (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) Author: wm624@hotmail.com Closes #12116 from wangmiao1981/fix_api. --- python/pyspark/ml/param/_shared_params_code_gen.py | 4 +++- python/pyspark/ml/param/shared.py | 24 ++++++++++++++++++++++ python/pyspark/ml/regression.py | 14 +++++++------ 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index 715fa9e9f8..a7615c43be 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -146,7 +146,9 @@ if __name__ == "__main__": ("weightCol", "weight column name. If this is not set or empty, we treat " + "all instance weights as 1.0.", None, "TypeConverters.toString"), ("solver", "the solver algorithm for optimization. If this is not set or empty, " + - "default value is 'auto'.", "'auto'", "TypeConverters.toString")] + "default value is 'auto'.", "'auto'", "TypeConverters.toString"), + ("varianceCol", "column name for the biased sample variance of prediction.", + None, "TypeConverters.toString")] code = [] for name, doc, defaultValueStr, typeConverter in shared: diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index d79d55e463..c9e975525c 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -559,6 +559,30 @@ class HasSolver(Params): return self.getOrDefault(self.solver) +class HasVarianceCol(Params): + """ + Mixin for param varianceCol: column name for the biased sample variance of prediction. 
+ """ + + varianceCol = Param(Params._dummy(), "varianceCol", "column name for the biased sample variance of prediction.", typeConverter=TypeConverters.toString) + + def __init__(self): + super(HasVarianceCol, self).__init__() + + def setVarianceCol(self, value): + """ + Sets the value of :py:attr:`varianceCol`. + """ + self._set(varianceCol=value) + return self + + def getVarianceCol(self): + """ + Gets the value of varianceCol or its default value. + """ + return self.getOrDefault(self.varianceCol) + + class DecisionTreeParams(Params): """ Mixin for Decision Tree parameters. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 00a6a0de90..f6c5d130dd 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -630,7 +630,7 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, - HasSeed, JavaMLWritable, JavaMLReadable): + HasSeed, JavaMLWritable, JavaMLReadable, HasVarianceCol): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -640,7 +640,7 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) - >>> dt = DecisionTreeRegressor(maxDepth=2) + >>> dt = DecisionTreeRegressor(maxDepth=2, varianceCol="variance") >>> model = dt.fit(df) >>> model.depth 1 @@ -666,6 +666,8 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi True >>> model.depth == model2.depth True + >>> model.transform(test1).head().variance + 0.0 .. versionadded:: 1.4.0 """ @@ -674,12 +676,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", - seed=None): + seed=None, varianceCol=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - impurity="variance", seed=None) + impurity="variance", seed=None, varianceCol=None) """ super(DecisionTreeRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -695,12 +697,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance", seed=None): + impurity="variance", seed=None, varianceCol=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ - impurity="variance", seed=None) + impurity="variance", seed=None, varianceCol=None) Sets params for the DecisionTreeRegressor. 
""" kwargs = self.setParams._input_kwargs -- cgit v1.2.3 From 94ac58b2a8ae83be670169062c8b83bf10e41d74 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 8 Apr 2016 11:26:28 -0700 Subject: [BUILD][HOTFIX] Download Maven from regular mirror network rather than archive.apache.org [archive.apache.org](https://archive.apache.org/) is undergoing maintenance, breaking our `build/mvn` script: > We are in the process of relocating this service. To save on the immense bandwidth that this service outputs, we have put it in maintenance mode, disabling all downloads for the next few days. We expect the maintenance to be complete no later than the morning of Monday the 11th of April, 2016. This patch fixes this issue by updating the script to use the regular mirror network to download Maven. Author: Josh Rosen Closes #12262 from JoshRosen/fix-mvn-download. --- build/mvn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build/mvn b/build/mvn index 54b815d2b8..eb42552fc4 100755 --- a/build/mvn +++ b/build/mvn @@ -70,7 +70,7 @@ install_app() { # Install maven under the build/ folder install_mvn() { local MVN_VERSION="3.3.9" - local APACHE_MIRROR=${APACHE_MIRROR:-https://archive.apache.org/dist} + local APACHE_MIRROR=${APACHE_MIRROR:-'https://www.apache.org/dyn/closer.lua?action=download&filename='} install_app \ "${APACHE_MIRROR}/maven/maven-3/${MVN_VERSION}/binaries" \ -- cgit v1.2.3 From 56af8e85cca056096fe4e765d8d287e0f9efc0d2 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 8 Apr 2016 11:49:44 -0700 Subject: [SPARK-14298][ML][MLLIB] LDA should support disable checkpoint ## What changes were proposed in this pull request? In the doc of [```checkpointInterval```](https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala#L241), we told users that they can disable checkpoint by setting ```checkpointInterval = -1```. But we did not handle this situation for LDA actually, we should fix this bug. ## How was this patch tested? Existing tests. cc jkbradley Author: Yanbo Liang Closes #12089 from yanboliang/spark-14298. --- .../scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala | 6 ++++-- .../org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index cbc8f60112..5c12c9305b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -52,7 +52,8 @@ import org.apache.spark.storage.StorageLevel * - This class removes checkpoint files once later Datasets have been checkpointed. * However, references to the older Datasets will still return isCheckpointed = true. * - * @param checkpointInterval Datasets will be checkpointed at this interval + * @param checkpointInterval Datasets will be checkpointed at this interval. + * If this interval was set as -1, then checkpointing will be disabled. 
* @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ @@ -89,7 +90,8 @@ private[mllib] abstract class PeriodicCheckpointer[T]( updateCount += 1 // Handle checkpointing (after persisting) - if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0 + && sc.getCheckpointDir.nonEmpty) { // Add new checkpoint before removing old checkpoints. checkpoint(newData) checkpointQueue.enqueue(newData) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 11a059536c..20db6084d0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -69,7 +69,8 @@ import org.apache.spark.storage.StorageLevel * // checkpointed: graph4 * }}} * - * @param checkpointInterval Graphs will be checkpointed at this interval + * @param checkpointInterval Graphs will be checkpointed at this interval. + * If this interval was set as -1, then checkpointing will be disabled. * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * -- cgit v1.2.3 From 02757535b58069ce8258108d89d8172a53c358e5 Mon Sep 17 00:00:00 2001 From: tedyu Date: Fri, 8 Apr 2016 12:25:36 -0700 Subject: [SPARK-14448] Improvements to ColumnVector ## What changes were proposed in this pull request? In this PR, two changes are proposed for ColumnVector : 1. ColumnVector should be declared as implementing AutoCloseable - it already has close() method 2. In OnHeapColumnVector#reserveInternal(), we only need to allocate new array when existing array is null or the length of existing array is shorter than the newCapacity. ## How was this patch tested? Existing unit tests. Author: tedyu Closes #12225 from tedyu/master. --- .../sql/execution/vectorized/ColumnVector.java | 2 +- .../execution/vectorized/OnHeapColumnVector.java | 56 ++++++++++++++-------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index d5daaf99df..0b276e6c77 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -56,7 +56,7 @@ import org.apache.spark.unsafe.types.UTF8String; * * ColumnVectors are intended to be reused. */ -public abstract class ColumnVector { +public abstract class ColumnVector implements AutoCloseable { /** * Allocates a column to store elements of `type` on or off heap. * Capacity is the initial capacity of the vector and it will grow as necessary. 
Capacity is diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 708a00953a..e97276800d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -387,35 +387,49 @@ public final class OnHeapColumnVector extends ColumnVector { arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { - byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); - byteData = newData; + if (byteData == null || byteData.length < newCapacity) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; + } } else if (type instanceof ByteType) { - byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); - byteData = newData; + if (byteData == null || byteData.length < newCapacity) { + byte[] newData = new byte[newCapacity]; + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + byteData = newData; + } } else if (type instanceof ShortType) { - short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); - shortData = newData; + if (shortData == null || shortData.length < newCapacity) { + short[] newData = new short[newCapacity]; + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + shortData = newData; + } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); - intData = newData; + if (intData == null || intData.length < newCapacity) { + int[] newData = new int[newCapacity]; + if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + intData = newData; + } } else if (type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { - long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); - longData = newData; + if (longData == null || longData.length < newCapacity) { + long[] newData = new long[newCapacity]; + if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + longData = newData; + } } else if (type instanceof FloatType) { - float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); - floatData = newData; + if (floatData == null || floatData.length < newCapacity) { + float[] newData = new float[newCapacity]; + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + floatData = newData; + } } else if (type instanceof DoubleType) { - double[] newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); - doubleData = newData; + if (doubleData == null || doubleData.length < newCapacity) { + double[] newData = new double[newCapacity]; + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + doubleData = newData; + } 
} else if (resultStruct != null) { // Nothing to store. } else { -- cgit v1.2.3 From f8c9beca38f1f396eb3220b23db6d77112a50293 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Fri, 8 Apr 2016 13:52:28 -0700 Subject: [SPARK-14394][SQL] Generate AggregateHashMap class for LongTypes during TungstenAggregate codegen ## What changes were proposed in this pull request? This PR adds support for generating the `AggregateHashMap` class in `TungstenAggregate` if the aggregate group by keys/value are of `LongType`. Note that currently this generate aggregate is not actually used. NB: This currently only supports `LongType` keys/values (please see `isAggregateHashMapSupported` in `TungstenAggregate`) and will be generalized to other data types in a subsequent PR. ## How was this patch tested? Manually inspected the generated code. This is what the generated map looks like for 2 keys: ```java /* 068 */ public class agg_GeneratedAggregateHashMap { /* 069 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; /* 070 */ private int[] buckets; /* 071 */ private int numBuckets; /* 072 */ private int maxSteps; /* 073 */ private int numRows = 0; /* 074 */ private org.apache.spark.sql.types.StructType schema = /* 075 */ new org.apache.spark.sql.types.StructType() /* 076 */ .add("k1", org.apache.spark.sql.types.DataTypes.LongType) /* 077 */ .add("k2", org.apache.spark.sql.types.DataTypes.LongType) /* 078 */ .add("sum", org.apache.spark.sql.types.DataTypes.LongType); /* 079 */ /* 080 */ public agg_GeneratedAggregateHashMap(int capacity, double loadFactor, int maxSteps) { /* 081 */ assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); /* 082 */ this.maxSteps = maxSteps; /* 083 */ numBuckets = (int) (capacity / loadFactor); /* 084 */ batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, /* 085 */ org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); /* 086 */ buckets = new int[numBuckets]; /* 087 */ java.util.Arrays.fill(buckets, -1); /* 088 */ } /* 089 */ /* 090 */ public agg_GeneratedAggregateHashMap() { /* 091 */ new agg_GeneratedAggregateHashMap(1 << 16, 0.25, 5); /* 092 */ } /* 093 */ /* 094 */ public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(long agg_key, long agg_key1) { /* 095 */ long h = hash(agg_key, agg_key1); /* 096 */ int step = 0; /* 097 */ int idx = (int) h & (numBuckets - 1); /* 098 */ while (step < maxSteps) { /* 099 */ // Return bucket index if it's either an empty slot or already contains the key /* 100 */ if (buckets[idx] == -1) { /* 101 */ batch.column(0).putLong(numRows, agg_key); /* 102 */ batch.column(1).putLong(numRows, agg_key1); /* 103 */ batch.column(2).putLong(numRows, 0); /* 104 */ buckets[idx] = numRows++; /* 105 */ return batch.getRow(buckets[idx]); /* 106 */ } else if (equals(idx, agg_key, agg_key1)) { /* 107 */ return batch.getRow(buckets[idx]); /* 108 */ } /* 109 */ idx = (idx + 1) & (numBuckets - 1); /* 110 */ step++; /* 111 */ } /* 112 */ // Didn't find it /* 113 */ return null; /* 114 */ } /* 115 */ /* 116 */ private boolean equals(int idx, long agg_key, long agg_key1) { /* 117 */ return batch.column(0).getLong(buckets[idx]) == agg_key && batch.column(1).getLong(buckets[idx]) == agg_key1; /* 118 */ } /* 119 */ /* 120 */ // TODO: Improve this Hash Function /* 121 */ private long hash(long agg_key, long agg_key1) { /* 122 */ return agg_key ^ agg_key1; /* 123 */ } /* 124 */ /* 125 */ } ``` Author: Sameer Agarwal Closes #12161 from sameeragarwal/tungsten-aggregate. 
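To make the generated lookup scheme easier to follow outside of codegen, here is a minimal hand-written Scala sketch of the same open-addressing idea, simplified to a single long grouping key and a single long sum and backed by plain arrays instead of a ColumnarBatch. The class and member names (`SimpleLongAggMap`, `add`) are illustrative only and do not exist in Spark; the actual artifact is the generated Java class shown above.

```scala
// Illustrative sketch only: a hand-rolled analogue of the generated aggregate hash map
// for one long grouping key and one long sum, using plain arrays instead of a ColumnarBatch.
class SimpleLongAggMap(capacity: Int, loadFactor: Double, maxSteps: Int) {
  require(capacity > 0 && (capacity & (capacity - 1)) == 0, "capacity must be a power of 2")
  private val numBuckets = (capacity / loadFactor).toInt
  private val buckets = Array.fill(numBuckets)(-1)   // bucket -> row index, -1 means empty
  private val keys = new Array[Long](capacity)
  private val sums = new Array[Long](capacity)
  private var numRows = 0

  /** Row index holding `key`, inserting it with sum 0 if absent; -1 means "fall back". */
  def findOrInsert(key: Long): Int = {
    // Same inexpensive hash as the generated code: the (xor-ed) key itself.
    var idx = (key & (numBuckets - 1)).toInt
    var step = 0
    while (step < maxSteps) {
      if (buckets(idx) == -1) {                      // empty slot: the key is not present yet
        if (numRows == capacity) return -1           // no room left for a new row
        keys(numRows) = key
        sums(numRows) = 0L
        buckets(idx) = numRows
        numRows += 1
        return buckets(idx)
      } else if (keys(buckets(idx)) == key) {        // key already present
        return buckets(idx)
      }
      idx = (idx + 1) & (numBuckets - 1)             // linear probing
      step += 1
    }
    -1                                               // too many collisions
  }

  /** Adds `value` to the running sum for `key`; returns false if the caller must fall back. */
  def add(key: Long, value: Long): Boolean = {
    val row = findOrInsert(key)
    if (row >= 0) { sums(row) += value; true } else false
  }
}
```

Returning -1 here plays the role of the generated `findOrInsert` returning null: the caller is expected to fall back to the regular `BytesToBytesMap` path for that key, which is what keeps the fast path correct for skewed or very large key sets.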
--- .../aggregate/ColumnarAggMapCodeGenerator.scala | 193 +++++++++++++++++++++ .../execution/aggregate/TungstenAggregate.scala | 20 ++- 2 files changed, 210 insertions(+), 3 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala new file mode 100644 index 0000000000..e415dd8e6a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.StructType + +/** + * This is a helper object to generate an append-only single-key/single value aggregate hash + * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates + * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in + * TungstenAggregate to speed up aggregates w/ key. + * + * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the + * key-value pairs. The index lookups in the array rely on linear probing (with a small number of + * maximum tries) and use an inexpensive hash function which makes it really efficient for a + * majority of lookups. However, using linear probing and an inexpensive hash function also makes it + * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even + * for certain distribution of keys) and requires us to fall back on the latter for correctness. 
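+ *
+ * Note: the generated map is not wired into the aggregation path yet; `isAggregateHashMapEnabled`
+ * in [[TungstenAggregate]] is currently hard-coded to false.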
+ */ +class ColumnarAggMapCodeGenerator( + ctx: CodegenContext, + generatedClassName: String, + groupingKeySchema: StructType, + bufferSchema: StructType) { + val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key"))) + val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value"))) + val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ") + + def generate(): String = { + s""" + |public class $generatedClassName { + |${initializeAggregateHashMap()} + | + |${generateFindOrInsert()} + | + |${generateEquals()} + | + |${generateHashFunction()} + |} + """.stripMargin + } + + private def initializeAggregateHashMap(): String = { + val generatedSchema: String = + s""" + |new org.apache.spark.sql.types.StructType() + |${(groupingKeySchema ++ bufferSchema).map(key => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") + .mkString("\n")}; + """.stripMargin + + s""" + | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; + | private int[] buckets; + | private int numBuckets; + | private int maxSteps; + | private int numRows = 0; + | private org.apache.spark.sql.types.StructType schema = $generatedSchema + | + | public $generatedClassName(int capacity, double loadFactor, int maxSteps) { + | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); + | this.maxSteps = maxSteps; + | numBuckets = (int) (capacity / loadFactor); + | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, + | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); + | buckets = new int[numBuckets]; + | java.util.Arrays.fill(buckets, -1); + | } + | + | public $generatedClassName() { + | new $generatedClassName(1 << 16, 0.25, 5); + | } + """.stripMargin + } + + /** + * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For + * instance, if we have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private long hash(long agg_key, long agg_key1) { + * return agg_key ^ agg_key1; + * } + * }}} + */ + private def generateHashFunction(): String = { + s""" + |// TODO: Improve this hash function + |private long hash($groupingKeySignature) { + | return ${groupingKeys.map(_._2).mkString(" ^ ")}; + |} + """.stripMargin + } + + /** + * Generates a method that returns true if the group-by keys exist at a given index in the + * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private boolean equals(int idx, long agg_key, long agg_key1) { + * return batch.column(0).getLong(buckets[idx]) == agg_key && + * batch.column(1).getLong(buckets[idx]) == agg_key1; + * } + * }}} + */ + private def generateEquals(): String = { + s""" + |private boolean equals(int idx, $groupingKeySignature) { + | return ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")}; + |} + """.stripMargin + } + + /** + * Generates a method that returns a mutable + * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the + * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the + * generated method adds the corresponding row in the associated + * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. 
For instance, if we + * have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert( + * long agg_key, long agg_key1) { + * long h = hash(agg_key, agg_key1); + * int step = 0; + * int idx = (int) h & (numBuckets - 1); + * while (step < maxSteps) { + * // Return bucket index if it's either an empty slot or already contains the key + * if (buckets[idx] == -1) { + * batch.column(0).putLong(numRows, agg_key); + * batch.column(1).putLong(numRows, agg_key1); + * batch.column(2).putLong(numRows, 0); + * buckets[idx] = numRows++; + * return batch.getRow(buckets[idx]); + * } else if (equals(idx, agg_key, agg_key1)) { + * return batch.getRow(buckets[idx]); + * } + * idx = (idx + 1) & (numBuckets - 1); + * step++; + * } + * // Didn't find it + * return null; + * } + * }}} + */ + private def generateFindOrInsert(): String = { + s""" + |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${ + groupingKeySignature}) { + | long h = hash(${groupingKeys.map(_._2).mkString(", ")}); + | int step = 0; + | int idx = (int) h & (numBuckets - 1); + | while (step < maxSteps) { + | // Return bucket index if it's either an empty slot or already contains the key + | if (buckets[idx] == -1) { + | ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")} + | ${bufferValues.zipWithIndex.map(k => + s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);") + .mkString("\n")} + | buckets[idx] = numRows++; + | return batch.getRow(buckets[idx]); + | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) { + | return batch.getRow(buckets[idx]); + | } + | idx = (idx + 1) & (numBuckets - 1); + | step++; + | } + | // Didn't find it + | return null; + |} + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 60027edc7c..0a5a72c52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( @@ -64,8 +64,8 @@ case class TungstenAggregate( override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } @@ -437,6 +437,19 @@ case class TungstenAggregate( val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + // create AggregateHashMap + val isAggregateHashMapEnabled: Boolean = false + val isAggregateHashMapSupported: Boolean = + (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) + val aggregateHashMapTerm = 
ctx.freshName("aggregateHashMap") + val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap") + val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName, + groupingKeySchema, bufferSchema) + if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { + ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, + s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") + } + // create hashMap val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") @@ -452,6 +465,7 @@ case class TungstenAggregate( val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, s""" + ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} -- cgit v1.2.3 From 464a3c1e02c665c7ad2709f8c47898b682526eb3 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 8 Apr 2016 13:58:58 -0700 Subject: [SPARK-14435][BUILD] Shade Kryo in our custom Hive 1.2.1 fork This patch updates our custom Hive 1.2.1 fork in order to shade Kryo in Hive. This is a blocker for upgrading Spark to use Kryo 3 (see #12076). The source for this new fork of Hive can be found at https://github.com/JoshRosen/hive/tree/release-1.2.1-spark2 Here's the complete diff from the official Hive 1.2.1 release: https://github.com/apache/hive/compare/release-1.2.1...JoshRosen:release-1.2.1-spark2 Here's the diff from the sources that pwendell used to publish the current `1.2.1.spark` release of Hive: https://github.com/pwendell/hive/compare/release-1.2.1-spark...JoshRosen:release-1.2.1-spark2. This diff looks large because his branch used a shell script to rewrite the groupId, whereas I had to commit the groupId changes in order to prevent the find-and-replace from affecting the package names in our relocated Kryo classes: https://github.com/pwendell/hive/compare/release-1.2.1-spark...JoshRosen:release-1.2.1-spark2#diff-6ada9aaec70e069df8f2c34c5519dd1e Using these changes, I was able to publish a local version of Hive and verify that this change fixes the test failures which are blocking #12076. Note that this PR will not compile until we complete the review of the Hive POM changes and stage and publish a release. /cc vanzin, steveloughran, and pwendell for review. Author: Josh Rosen Closes #12215 from JoshRosen/shade-kryo-in-hive. 
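As a quick sanity check while reviewing a locally staged `1.2.1.spark2` artifact, a classloader probe along the lines of the updated `ClasspathDependenciesSuite` is enough. The snippet below is only an illustrative sketch (the object name is made up); the class names are taken from the relocation pattern asserted by that suite, and it is not an addition to this patch:

```scala
object ShadedHiveClasspathCheck {
  // Returns true if the class resolves on the current classpath, without initializing it.
  private def loads(name: String): Boolean =
    try { Class.forName(name, false, getClass.getClassLoader); true }
    catch { case _: ClassNotFoundException => false }

  def main(args: Array[String]): Unit = {
    // Kryo and Protobuf relocated under the Hive fork's prefix must be resolvable...
    assert(loads("org.apache.hive.com.esotericsoftware.kryo.Kryo"))
    assert(loads("org.apache.hive.com.google.protobuf.ServiceException"))
    // ...while the fork must no longer leak its previously shaded Objenesis strategy.
    assert(!loads("com.esotericsoftware.shaded.org.objenesis.strategy.StdInstantiatorStrategy"))
    assert(!loads("org.apache.hive.com.esotericsoftware.shaded.org.objenesis.strategy.StdInstantiatorStrategy"))
  }
}
```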
--- pom.xml | 2 +- .../scala/org/apache/spark/sql/hive/HiveShim.scala | 4 +-- .../sql/hive/ClasspathDependenciesSuite.scala | 41 +++++----------------- 3 files changed, 12 insertions(+), 35 deletions(-) diff --git a/pom.xml b/pom.xml index 1b40983a6c..f37a8988f7 100644 --- a/pom.xml +++ b/pom.xml @@ -131,7 +131,7 @@ 2.4.0 org.spark-project.hive - 1.2.1.spark + 1.2.1.spark2 1.2.1 10.10.1.1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index da910533d0..0d2a765a38 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -24,8 +24,6 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.esotericsoftware.kryo.Kryo -import com.esotericsoftware.kryo.io.{Input, Output} import com.google.common.base.Objects import org.apache.avro.Schema import org.apache.hadoop.conf.Configuration @@ -37,6 +35,8 @@ import org.apache.hadoop.hive.serde2.ColumnProjectionUtils import org.apache.hadoop.hive.serde2.avro.{AvroGenericRecordWritable, AvroSerdeUtils} import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector import org.apache.hadoop.io.Writable +import org.apache.hive.com.esotericsoftware.kryo.Kryo +import org.apache.hive.com.esotericsoftware.kryo.io.{Input, Output} import org.apache.spark.internal.Logging import org.apache.spark.sql.types.Decimal diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala index 34b2edb44b..f262ef62be 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ClasspathDependenciesSuite.scala @@ -24,9 +24,7 @@ import org.apache.spark.SparkFunSuite /** * Verify that some classes load and that others are not found on the classpath. * - * - * This is used to detect classpath and shading conflict, especially between - * Spark's required Kryo version and that which can be found in some Hive versions. + * This is used to detect classpath and shading conflicts. */ class ClasspathDependenciesSuite extends SparkFunSuite { private val classloader = this.getClass.getClassLoader @@ -40,10 +38,6 @@ class ClasspathDependenciesSuite extends SparkFunSuite { classloader.loadClass(classname) } - private def assertLoads(classes: String*): Unit = { - classes.foreach(assertLoads) - } - private def findResource(classname: String): URL = { val resource = resourceName(classname) classloader.getResource(resource) @@ -63,17 +57,12 @@ class ClasspathDependenciesSuite extends SparkFunSuite { } } - private def assertClassNotFound(classes: String*): Unit = { - classes.foreach(assertClassNotFound) + test("shaded Protobuf") { + assertLoads("org.apache.hive.com.google.protobuf.ServiceException") } - private val KRYO = "com.esotericsoftware.kryo.Kryo" - - private val SPARK_HIVE = "org.apache.hive." - private val SPARK_SHADED = "org.spark-project.hive.shaded." 
- - test("shaded Protobuf") { - assertLoads(SPARK_SHADED + "com.google.protobuf.ServiceException") + test("shaded Kryo") { + assertLoads("org.apache.hive.com.esotericsoftware.kryo.Kryo") } test("hive-common") { @@ -86,25 +75,13 @@ class ClasspathDependenciesSuite extends SparkFunSuite { private val STD_INSTANTIATOR = "org.objenesis.strategy.StdInstantiatorStrategy" - test("unshaded kryo") { - assertLoads(KRYO, STD_INSTANTIATOR) - } - test("Forbidden Dependencies") { - assertClassNotFound( - SPARK_HIVE + KRYO, - SPARK_SHADED + KRYO, - "org.apache.hive." + KRYO, - "com.esotericsoftware.shaded." + STD_INSTANTIATOR, - SPARK_HIVE + "com.esotericsoftware.shaded." + STD_INSTANTIATOR, - "org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR - ) + assertClassNotFound("com.esotericsoftware.shaded." + STD_INSTANTIATOR) + assertClassNotFound("org.apache.hive.com.esotericsoftware.shaded." + STD_INSTANTIATOR) } test("parquet-hadoop-bundle") { - assertLoads( - "parquet.hadoop.ParquetOutputFormat", - "parquet.hadoop.ParquetInputFormat" - ) + assertLoads("parquet.hadoop.ParquetOutputFormat") + assertLoads("parquet.hadoop.ParquetInputFormat") } } -- cgit v1.2.3 From 906eef4c7a380419f2d089262afdcf39454fe31e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 8 Apr 2016 16:35:30 -0700 Subject: [SPARK-11416][BUILD] Update to Chill 0.8.0 & Kryo 3.0.3 This patch upgrades Chill to 0.8.0 and Kryo to 3.0.3. While we'll likely need to bump these dependencies again before Spark 2.0 (due to SPARK-14221 / https://github.com/twitter/chill/issues/252), I wanted to get the bulk of the Kryo 2 -> Kryo 3 migration done now in order to figure out whether there are any unexpected surprises. Author: Josh Rosen Closes #12076 from JoshRosen/kryo3. --- LICENSE | 5 ++--- dev/deps/spark-deps-hadoop-2.2 | 11 +++++------ dev/deps/spark-deps-hadoop-2.3 | 11 +++++------ dev/deps/spark-deps-hadoop-2.4 | 11 +++++------ dev/deps/spark-deps-hadoop-2.6 | 11 +++++------ dev/deps/spark-deps-hadoop-2.7 | 11 +++++------ pom.xml | 22 +--------------------- 7 files changed, 28 insertions(+), 54 deletions(-) diff --git a/LICENSE b/LICENSE index 5a8c78b98b..9714b3b1e4 100644 --- a/LICENSE +++ b/LICENSE @@ -257,9 +257,8 @@ The text of each license is also included at licenses/LICENSE-[project].txt. 
(BSD-style) scalacheck (org.scalacheck:scalacheck_2.11:1.10.0 - http://www.scalacheck.org) (BSD-style) spire (org.spire-math:spire_2.11:0.7.1 - http://spire-math.org) (BSD-style) spire-macros (org.spire-math:spire-macros_2.11:0.7.1 - http://spire-math.org) - (New BSD License) Kryo (com.esotericsoftware.kryo:kryo:2.21 - http://code.google.com/p/kryo/) - (New BSD License) MinLog (com.esotericsoftware.minlog:minlog:1.2 - http://code.google.com/p/minlog/) - (New BSD License) ReflectASM (com.esotericsoftware.reflectasm:reflectasm:1.07 - http://code.google.com/p/reflectasm/) + (New BSD License) Kryo (com.esotericsoftware:kryo:3.0.3 - https://github.com/EsotericSoftware/kryo) + (New BSD License) MinLog (com.esotericsoftware:minlog:1.3.0 - https://github.com/EsotericSoftware/minlog) (New BSD license) Protocol Buffer Java API (com.google.protobuf:protobuf-java:2.5.0 - http://code.google.com/p/protobuf) (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 2794b3d235..023fba5369 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -20,8 +20,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -123,7 +123,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -136,10 +136,10 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -157,7 +157,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 4906fe9cfa..003c540d72 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -22,8 +22,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -114,7 +114,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -126,11 +126,11 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -148,7 +148,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar 
scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 23ff5cfa2e..80fbaea222 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -22,8 +22,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -115,7 +115,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -127,11 +127,11 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -149,7 +149,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 9b5a5643f3..b2c2a4caec 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -26,8 +26,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -121,7 +121,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -133,11 +133,11 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -155,7 +155,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 1dca2fc55a..71e51883d5 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -26,8 +26,8 @@ breeze_2.11-0.11.2.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar -chill-java-0.7.4.jar -chill_2.11-0.7.4.jar +chill-java-0.8.0.jar +chill_2.11-0.8.0.jar commons-beanutils-1.7.0.jar commons-beanutils-core-1.8.0.jar commons-cli-1.2.jar @@ -122,7 +122,7 @@ jsr305-1.3.9.jar jta-1.1.jar jtransforms-2.4.0.jar jul-to-slf4j-1.7.16.jar -kryo-2.21.jar +kryo-shaded-3.0.3.jar leveldbjni-all-1.8.jar libfb303-0.9.2.jar libthrift-0.9.2.jar @@ -134,11 +134,11 @@ metrics-core-3.1.2.jar metrics-graphite-3.1.2.jar metrics-json-3.1.2.jar metrics-jvm-3.1.2.jar -minlog-1.2.jar +minlog-1.3.0.jar mx4j-3.0.2.jar netty-3.8.0.Final.jar netty-all-4.0.29.Final.jar -objenesis-1.2.jar +objenesis-2.1.jar opencsv-2.3.jar oro-2.0.8.jar paranamer-2.6.jar @@ -156,7 +156,6 @@ pmml-schema-1.2.7.jar protobuf-java-2.5.0.jar py4j-0.9.2.jar 
pyrolite-4.9.jar -reflectasm-1.07-shaded.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar scala-parser-combinators_2.11-1.0.4.jar diff --git a/pom.xml b/pom.xml index f37a8988f7..3f9e4abc32 100644 --- a/pom.xml +++ b/pom.xml @@ -139,7 +139,7 @@ 1.6.0 8.1.14.v20131031 3.0.0.v201112011016 - 0.7.4 + 0.8.0 2.4.0 2.0.8 3.1.2 @@ -277,31 +277,11 @@ com.twitter chill_${scala.binary.version} ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - com.twitter chill-java ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-mllib-local_2.11 + + mllib-local + + jar + Spark Project ML Local Library + http://spark.apache.org/ + + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.scalanlp + breeze_${scala.binary.version} + 0.11.2 + + + + junit + junit + + + org.apache.commons + commons-math3 + + + + + org.apache.commons + commons-math3 + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.mockito + mockito-core + test + + + + + netlib-lgpl + + + com.github.fommil.netlib + all + ${netlib.java.version} + pom + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala new file mode 100644 index 0000000000..6b3268cdfa --- /dev/null +++ b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +// This is a private class testing if the new build works. To be removed soon. +private[ml] object DummyTesting { + private[ml] def add10(input: Double): Double = input + 10 +} diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala new file mode 100644 index 0000000000..6c76dbfbfa --- /dev/null +++ b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.SparkFunSuite + +// This is testing if the new build works. To be removed soon. +class DummyTestingSuite extends SparkFunSuite { + + test("This is testing if the new build works.") { + assert(DummyTesting.add10(15) === 25) + } +} diff --git a/mllib/pom.xml b/mllib/pom.xml index 428176dcbf..e56eafc300 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -62,6 +62,18 @@ spark-graphx_${scala.binary.version} ${project.version} + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.version} + test-jar + test + org.scalanlp breeze_${scala.binary.version} diff --git a/pom.xml b/pom.xml index 3f9e4abc32..58c85b1d36 100644 --- a/pom.xml +++ b/pom.xml @@ -94,6 +94,7 @@ core graphx mllib + mllib-local tools streaming sql/catalyst diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 60124ef0a1..c5688ecec6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -47,9 +47,9 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( - core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* + core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* ) = Seq( - "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe", + "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", "test-tags", "sketch" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects @@ -254,7 +254,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, testTags, sketch + unsafe, testTags, sketch, mllibLocal ).contains(x) } -- cgit v1.2.3 From adb9d73cd6543c9edfc6b03a6d20061ff09c69f9 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Sat, 9 Apr 2016 11:25:39 -0700 Subject: [SPARK-14339][DOC] Add python examples for DCT,MinMaxScaler,MaxAbsScaler ## What changes were proposed in this pull request? add three python examples ## How was this patch tested? manual tests Author: Zheng RuiFeng Closes #12063 from zhengruifeng/dct_pe. --- docs/ml-features.md | 24 ++++++++++++ examples/src/main/python/ml/dct_example.py | 45 ++++++++++++++++++++++ .../src/main/python/ml/max_abs_scaler_example.py | 43 +++++++++++++++++++++ .../src/main/python/ml/min_max_scaler_example.py | 43 +++++++++++++++++++++ 4 files changed, 155 insertions(+) create mode 100644 examples/src/main/python/ml/dct_example.py create mode 100644 examples/src/main/python/ml/max_abs_scaler_example.py create mode 100644 examples/src/main/python/ml/min_max_scaler_example.py diff --git a/docs/ml-features.md b/docs/ml-features.md index 4fe8eefc26..5cc27d3565 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -413,6 +413,14 @@ for more details on the API. 
 {% include_example java/org/apache/spark/examples/ml/JavaDCTExample.java %}
 </div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [DCT Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.DCT)
+for more details on the API.
+
+{% include_example python/ml/dct_example.py %}
+</div>
 </div>
 
 ## StringIndexer
@@ -771,6 +779,14 @@ for more details on the API.
 
 {% include_example java/org/apache/spark/examples/ml/JavaMinMaxScalerExample.java %}
 </div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [MinMaxScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MinMaxScaler)
+for more details on the API.
+
+{% include_example python/ml/min_max_scaler_example.py %}
+</div>
 </div>
 
@@ -803,6 +819,14 @@ for more details on the API.
 
 {% include_example java/org/apache/spark/examples/ml/JavaMaxAbsScalerExample.java %}
 </div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [MaxAbsScaler Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.MaxAbsScaler)
+for more details on the API.
+
+{% include_example python/ml/max_abs_scaler_example.py %}
+</div>
    ## Bucketizer diff --git a/examples/src/main/python/ml/dct_example.py b/examples/src/main/python/ml/dct_example.py new file mode 100644 index 0000000000..264d47f404 --- /dev/null +++ b/examples/src/main/python/ml/dct_example.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import DCT +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="DCTExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (Vectors.dense([0.0, 1.0, -2.0, 3.0]),), + (Vectors.dense([-1.0, 2.0, 4.0, -7.0]),), + (Vectors.dense([14.0, -2.0, -5.0, 1.0]),)], ["features"]) + + dct = DCT(inverse=False, inputCol="features", outputCol="featuresDCT") + + dctDf = dct.transform(df) + + for dcts in dctDf.select("featuresDCT").take(3): + print(dcts) + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/max_abs_scaler_example.py b/examples/src/main/python/ml/max_abs_scaler_example.py new file mode 100644 index 0000000000..d9b69eef1c --- /dev/null +++ b/examples/src/main/python/ml/max_abs_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import MaxAbsScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="MaxAbsScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + scaler = MaxAbsScaler(inputCol="features", outputCol="scaledFeatures") + + # Compute summary statistics and generate MaxAbsScalerModel + scalerModel = scaler.fit(dataFrame) + + # rescale each feature to range [-1, 1]. 
+ scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/min_max_scaler_example.py b/examples/src/main/python/ml/min_max_scaler_example.py new file mode 100644 index 0000000000..2f8e4ade46 --- /dev/null +++ b/examples/src/main/python/ml/min_max_scaler_example.py @@ -0,0 +1,43 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import MinMaxScaler +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="MinMaxScalerExample") + sqlContext = SQLContext(sc) + + # $example on$ + dataFrame = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + scaler = MinMaxScaler(inputCol="features", outputCol="scaledFeatures") + + # Compute summary statistics and generate MinMaxScalerModel + scalerModel = scaler.fit(dataFrame) + + # rescale each feature to range [min, max]. + scaledData = scalerModel.transform(dataFrame) + scaledData.show() + # $example off$ + + sc.stop() -- cgit v1.2.3 From f7ec854f1b7f575c4c7437daf8e6992c684b6de2 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 9 Apr 2016 13:51:28 -0700 Subject: Revert "[SPARK-14419] [SQL] Improve HashedRelation for key fit within Long" This reverts commit 90c0a04506a4972b7a2ac2b7dda0c5f8509a6e2f. 
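Among the code this revert restores is `HashJoin.rewriteKeyExpr`, which packs several integral join keys into a single `Long` and rotates the bits of 32-bit keys so that pairs of equal ints do not all hash to zero (see the restored comment in the `HashJoin.scala` hunk below). The following standalone sketch is not part of the revert and only mirrors the idea in plain Scala arithmetic, not the exact expression tree:

```scala
object KeyPackingIllustration {
  // Mirrors the (e >>> 15) | (e << 17) rotation applied to int join keys.
  private def rotate(i: Int): Int = (i >>> 15) | (i << 17)

  // Pack two int keys into one Long, the low 32 bits holding the rotated second key.
  private def pack(a: Int, b: Int): Long = (a.toLong << 32) | (rotate(b).toLong & 0xffffffffL)

  def main(args: Array[String]): Unit = {
    // Without the rotation, equal ints produce a Long whose high and low halves match,
    // so Long.hashCode -- (l >>> 32) ^ l -- is 0 for every such key.
    val naive = (7L << 32) | (7L & 0xffffffffL)
    assert(naive.hashCode == 0)
    // With the rotation the degenerate case disappears.
    assert(pack(7, 7).hashCode != 0)
  }
}
```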
--- .../execution/aggregate/TungstenAggregate.scala | 3 +- .../sql/execution/joins/BroadcastHashJoin.scala | 18 +- .../spark/sql/execution/joins/HashJoin.scala | 31 +- .../spark/sql/execution/joins/HashedRelation.scala | 688 +++++++-------------- .../sql/execution/joins/ShuffledHashJoin.scala | 51 +- .../sql/execution/BenchmarkWholeStageCodegen.scala | 132 +--- .../apache/spark/sql/execution/ExchangeSuite.scala | 8 +- .../sql/execution/joins/HashedRelationSuite.scala | 48 +- 8 files changed, 346 insertions(+), 633 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 692fef703f..0a5a72c52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -454,7 +454,7 @@ case class TungstenAggregate( val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"") + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") @@ -467,7 +467,6 @@ case class TungstenAggregate( s""" ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { - $hashMapTerm = $thisPlan.createHashMap(); ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index a8f854136c..e3d554c2de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.LongType /** * Performs an inner hash join of two child relations. 
When the output RDD of this operator is @@ -51,7 +50,10 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - val mode = HashedRelationBroadcastMode(buildKeys) + val mode = HashedRelationBroadcastMode( + canJoinKeyFitWithinLong, + rewriteKeyExpr(buildKeys), + buildPlan.output) buildSide match { case BuildLeft => BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -66,7 +68,7 @@ case class BroadcastHashJoin( val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize) join(streamedIter, hashed, numOutputRows) } } @@ -103,7 +105,7 @@ case class BroadcastHashJoin( ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.estimatedSize()); + | incPeakExecutionMemory($relationTerm.getMemorySize()); """.stripMargin) (broadcastRelation, relationTerm) } @@ -116,13 +118,15 @@ case class BroadcastHashJoin( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { + if (canJoinKeyFitWithinLong) { // generate the join key as Long - val ev = streamedKeys.head.gen(ctx) + val expr = rewriteKeyExpr(streamedKeys).head + val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) (ev, ev.isNull) } else { // generate the join key as UnsafeRow - val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) + val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) + val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) (ev, s"${ev.value}.anyNull()") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 4c912d371e..8f45d57126 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -59,13 +59,9 @@ trait HashJoin { case BuildRight => (right, left) } - protected lazy val (buildKeys, streamedKeys) = { - val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) - val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) - buildSide match { - case BuildLeft => (lkeys, rkeys) - case BuildRight => (rkeys, lkeys) - } + protected lazy val (buildKeys, streamedKeys) = buildSide match { + case BuildLeft => (leftKeys, rightKeys) + case BuildRight => (rightKeys, leftKeys) } /** @@ -88,8 +84,17 @@ trait HashJoin { width = dt.defaultSize } else { val bits = dt.defaultSize * 8 + // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same + // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys + // with two same ints have hash code 0, we rotate the bits of second one. 
+ val rotated = if (e.dataType == IntegerType) { + // (e >>> 15) | (e << 17) + BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) + } else { + e + } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType @@ -100,11 +105,17 @@ trait HashJoin { keyExpr :: Nil } + protected lazy val canJoinKeyFitWithinLong: Boolean = { + val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType) + val key = rewriteKeyExpr(buildKeys) + sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] + } + protected def buildSideKeyGenerator(): Projection = - UnsafeProjection.create(buildKeys) + UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) protected def streamSideKeyGenerator(): UnsafeProjection = - UnsafeProjection.create(streamedKeys) + UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 4959f60dab..5ccb435686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,22 +18,24 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} +import java.util.{HashMap => JavaHashMap} -import org.apache.spark.{SparkConf, SparkEnv, SparkException} -import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, Utils} +import org.apache.spark.util.collection.CompactBuffer /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. */ -private[execution] sealed trait HashedRelation extends KnownSizeEstimation { +private[execution] sealed trait HashedRelation { /** * Returns matched rows. * @@ -72,36 +74,51 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation { */ def asReadOnlyCopy(): HashedRelation + /** + * Returns the size of used memory. + */ + def getMemorySize: Long = 1L // to make the test happy + /** * Release any used resources. 
*/ - def close(): Unit + def close(): Unit = {} + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { + out.writeInt(serialized.length) // Write the length of serialized bytes first + out.write(serialized) + } + + // This is a helper method to implement Externalizable, and is used by + // GeneralHashedRelation and UniqueKeyHashedRelation + protected def readBytes(in: ObjectInput): Array[Byte] = { + val serializedSize = in.readInt() // Read the length of serialized bytes first + val bytes = new Array[Byte](serializedSize) + in.readFully(bytes) + bytes + } } private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. + * + * Note: The caller should make sure that these InternalRow are different objects. */ def apply( + canJoinKeyFitWithinLong: Boolean, input: Iterator[InternalRow], - key: Seq[Expression], - sizeEstimate: Int = 64, - taskMemoryManager: TaskMemoryManager = null): HashedRelation = { - val mm = Option(taskMemoryManager).getOrElse { - new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - } + keyGenerator: Projection, + sizeEstimate: Int = 64): HashedRelation = { - if (key.length == 1 && key.head.dataType == LongType) { - LongHashedRelation(input, key, sizeEstimate, mm) + if (canJoinKeyFitWithinLong) { + LongHashedRelation(input, keyGenerator, sizeEstimate) } else { - UnsafeHashedRelation(input, key, sizeEstimate, mm) + UnsafeHashedRelation( + input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) } } } @@ -116,7 +133,7 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numFields: Int, private var binaryMap: BytesToBytesMap) - extends HashedRelation with Externalizable { + extends HashedRelation with KnownSizeEstimation with Externalizable { private[joins] def this() = this(0, null) // Needed for serialization @@ -125,6 +142,10 @@ private[joins] class UnsafeHashedRelation( override def asReadOnlyCopy(): UnsafeHashedRelation = new UnsafeHashedRelation(numFields, binaryMap) + override def getMemorySize: Long = { + binaryMap.getTotalMemoryConsumption + } + override def estimatedSize: Long = { binaryMap.getTotalMemoryConsumption } @@ -255,10 +276,20 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - key: Seq[Expression], - sizeEstimate: Int, - taskMemoryManager: TaskMemoryManager): HashedRelation = { + keyGenerator: UnsafeProjection, + sizeEstimate: Int): HashedRelation = { + val taskMemoryManager = if (TaskContext.get() != null) { + TaskContext.get().taskMemoryManager() + } else { + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + } val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) @@ -269,7 +300,6 @@ private[joins] object UnsafeHashedRelation { pageSizeBytes) // Create a mapping of buildKeys -> rows - val keyGenerator = UnsafeProjection.create(key) var numFields = 0 while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] @@ -291,471 +321,144 @@ private[joins] object UnsafeHashedRelation { } } -private[joins] object LongToUnsafeRowMap { - // the largest 
prime that below 2^n - val LARGEST_PRIMES = { - // https://primes.utm.edu/lists/2small/0bit.html - val diffs = Seq( - 0, 1, 1, 3, 1, 3, 1, 5, - 3, 3, 9, 3, 1, 3, 19, 15, - 1, 5, 1, 3, 9, 3, 15, 3, - 39, 5, 39, 57, 3, 35, 1, 5 - ) - val primes = new Array[Int](32) - primes(0) = 1 - var power2 = 1 - (1 until 32).foreach { i => - power2 *= 2 - primes(i) = power2 - diffs(i) - } - primes - } -} - /** - * An append-only hash map mapping from key of Long to UnsafeRow. - * - * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array - * (`page`) in this format: - * - * [bytes of row1][address1][bytes of row2][address1] ... - * - * address1 (8 bytes) is the offset and size of next value for the same key as row1, any key - * could have multiple values. the address at the end of last value for every key is 0. - * - * The keys and addresses of their values could be stored in two modes: - * - * 1) sparse mode: the keys and addresses are stored in `array` as: - * - * [key1][address1][key2][address2]...[] - * - * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1 - * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to address - * hash collision. - * - * 2) dense mode: all the addresses are packed into a single array of long, as: - * - * [address1] [address2] ... - * - * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is - * determined by `key1 - minKey`. - * - * The map is created as sparse mode, then key-value could be appended into it. Once finish - * appending, caller could all optimize() to try to turn the map into dense mode, which is faster - * to probe. + * An interface for a hashed relation that the key is a Long. */ -private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) - extends MemoryConsumer(mm) with Externalizable { - import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap._ - - // Whether the keys are stored in dense mode or not. - private var isDense = false - - // The minimum value of keys. - private var minKey = Long.MaxValue - - // The Maxinum value of keys. - private var maxKey = Long.MinValue - - // Sparse mode: the actual capacity of map, is a prime number. - private var cap: Int = 0 - - // The array to store the key and offset of UnsafeRow in the page. - // - // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... - // Dense mode: [offset1 | size1] [offset2 | size2] - private var array: Array[Long] = null - - // The page to store all bytes of UnsafeRow and the pointer to next rows. - // [row1][pointer1] [row2][pointer2] - private var page: Array[Byte] = null - - // Current write cursor in the page. - private var cursor = Platform.BYTE_ARRAY_OFFSET - - // The total number of values of all keys. - private var numValues = 0 - - // The number of unique keys. 
- private var numKeys = 0 - - // needed by serializer - def this() = { - this( - new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0), - 0) - } - - private def acquireMemory(size: Long): Unit = { - // do not support spilling - val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) - if (got < size) { - mm.releaseExecutionMemory(got, MemoryMode.ON_HEAP, this) - throw new SparkException(s"Can't acquire $size bytes memory to build hash relation") - } - } - - private def freeMemory(size: Long): Unit = { - mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this) - } - - private def init(): Unit = { - if (mm != null) { - cap = LARGEST_PRIMES.find(_ > capacity).getOrElse{ - sys.error(s"Can't create map with capacity $capacity") - } - acquireMemory(cap * 2 * 8 + (1 << 20)) - array = new Array[Long](cap * 2) - page = new Array[Byte](1 << 20) // 1M bytes - } - } - - init() - - def spill(size: Long, trigger: MemoryConsumer): Long = { - 0L - } - - /** - * Returns whether all the keys are unique. - */ - def keyIsUnique: Boolean = numKeys == numValues - - /** - * Returns total memory consumption. - */ - def getTotalMemoryConsumption: Long = { - array.length * 8 + page.length - } - - /** - * Returns the slot of array that store the keys (sparse mode). - */ - private def getSlot(key: Long): Int = { - var s = (key % cap).toInt - if (s < 0) { - s += cap - } - s * 2 - } - - private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { - val offset = address >>> 32 - val size = address & 0xffffffffL - resultRow.pointTo(page, offset, size.toInt) - resultRow +private[joins] trait LongHashedRelation extends HashedRelation { + override def get(key: InternalRow): Iterator[InternalRow] = { + get(key.getLong(0)) } - - /** - * Returns the single UnsafeRow for given key, or null if not found. - */ - def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { - if (isDense) { - val idx = (key - minKey).toInt - if (idx >= 0 && key <= maxKey && array(idx) > 0) { - return getRow(array(idx), resultRow) - } - } else { - var pos = getSlot(key) - var step = 1 - while (array(pos + 1) != 0) { - if (array(pos) == key) { - return getRow(array(pos + 1), resultRow) - } - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } - } - } - null + override def getValue(key: InternalRow): InternalRow = { + getValue(key.getLong(0)) } +} - /** - * Returns an interator of UnsafeRow for multiple linked values. - */ - private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { - new Iterator[UnsafeRow] { - var addr = address - override def hasNext: Boolean = addr != 0 - override def next(): UnsafeRow = { - val offset = addr >>> 32 - val size = addr & 0xffffffffL - resultRow.pointTo(page, offset, size.toInt) - addr = Platform.getLong(page, offset + size) - resultRow - } - } - } +private[joins] final class GeneralLongHashedRelation( + private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]]) + extends LongHashedRelation with Externalizable { - /** - * Returns an iterator for all the values for the given key, or null if no value found. 
- */ - def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { - if (isDense) { - val idx = (key - minKey).toInt - if (idx >=0 && key <= maxKey && array(idx) > 0) { - return valueIter(array(idx), resultRow) - } - } else { - var pos = getSlot(key) - var step = 1 - while (array(pos + 1) != 0) { - if (array(pos) == key) { - return valueIter(array(pos + 1), resultRow) - } - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } - } - } - null - } - - /** - * Appends the key and row into this map. - */ - def append(key: Long, row: UnsafeRow): Unit = { - if (key < minKey) { - minKey = key - } - if (key > maxKey) { - maxKey = key - } + // Needed for serialization (it is public to make Java serialization work) + def this() = this(null) - // There is 8 bytes for the pointer to next value - if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { - val used = page.length - if (used * 2L > (1L << 31)) { - sys.error("Can't allocate a page that is larger than 2G") - } - acquireMemory(used * 2) - val newPage = new Array[Byte](used * 2) - System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) - page = newPage - freeMemory(used) - } + override def keyIsUnique: Boolean = false - // copy the bytes of UnsafeRow - val offset = cursor - Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) - cursor += row.getSizeInBytes - Platform.putLong(page, cursor, 0) - cursor += 8 - numValues += 1 - updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes) - } + override def asReadOnlyCopy(): GeneralLongHashedRelation = + new GeneralLongHashedRelation(hashTable) - /** - * Update the address in array for given key. - */ - private def updateIndex(key: Long, address: Long): Unit = { - var pos = getSlot(key) - var step = 1 - while (array(pos + 1) != 0 && array(pos) != key) { - pos += 2 * step - step += 1 - if (pos >= array.length) { - pos -= array.length - } - } - if (array(pos + 1) == 0) { - // this is the first value for this key, put the address in array. - array(pos) = key - array(pos + 1) = address - numKeys += 1 - if (numKeys * 2 > cap) { - // reach half of the capacity - growArray() - } + override def get(key: Long): Iterator[InternalRow] = { + val rows = hashTable.get(key) + if (rows != null) { + rows.toIterator } else { - // there is another value for this key, put the address at the end of final value. - var addr = array(pos + 1) - var pointer = (addr >>> 32) + (addr & 0xffffffffL) - while (Platform.getLong(page, pointer) != 0) { - addr = Platform.getLong(page, pointer) - pointer = (addr >>> 32) + (addr & 0xffffffffL) - } - Platform.putLong(page, pointer, address) - } - } - - private def growArray(): Unit = { - val old_cap = cap - var old_array = array - cap = LARGEST_PRIMES.find(_ > cap).getOrElse{ - sys.error(s"Can't grow map any more than $cap") - } - numKeys = 0 - acquireMemory(cap * 2 * 8) - array = new Array[Long](cap * 2) - var i = 0 - while (i < old_array.length) { - if (old_array(i + 1) > 0) { - updateIndex(old_array(i), old_array(i + 1)) - } - i += 2 - } - old_array = null // release the reference to old array - freeMemory(old_cap * 2 * 8) - } - - /** - * Try to turn the map into dense mode, which is faster to probe. 
- */ - def optimize(): Unit = { - val range = maxKey - minKey - // Convert to dense mode if it does not require more memory or could fit within L1 cache - if (range < array.length || range < 1024) { - try { - acquireMemory((range + 1) * 8) - } catch { - case e: SparkException => - // there is no enough memory to convert - return - } - val denseArray = new Array[Long]((range + 1).toInt) - var i = 0 - while (i < array.length) { - if (array(i + 1) > 0) { - val idx = (array(i) - minKey).toInt - denseArray(idx) = array(i + 1) - } - i += 2 - } - val old_length = array.length - array = denseArray - isDense = true - freeMemory(old_length * 8) - } - } - - /** - * Free all the memory acquired by this map. - */ - def free(): Unit = { - if (page != null) { - freeMemory(page.length) - page = null - } - if (array != null) { - freeMemory(array.length * 8) - array = null + null } } override def writeExternal(out: ObjectOutput): Unit = { - out.writeBoolean(isDense) - out.writeLong(minKey) - out.writeLong(maxKey) - out.writeInt(numKeys) - out.writeInt(numValues) - out.writeInt(cap) - - out.writeInt(array.length) - val buffer = new Array[Byte](4 << 10) - var offset = Platform.LONG_ARRAY_OFFSET - val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET - while (offset < end) { - val size = Math.min(buffer.length, end - offset) - Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) - out.write(buffer, 0, size) - offset += size - } - - val used = cursor - Platform.BYTE_ARRAY_OFFSET - out.writeInt(used) - out.write(page, 0, used) + writeBytes(out, SparkSqlSerializer.serialize(hashTable)) } override def readExternal(in: ObjectInput): Unit = { - isDense = in.readBoolean() - minKey = in.readLong() - maxKey = in.readLong() - numKeys = in.readInt() - numValues = in.readInt() - cap = in.readInt() - - val length = in.readInt() - array = new Array[Long](length) - val buffer = new Array[Byte](4 << 10) - var offset = Platform.LONG_ARRAY_OFFSET - val end = length * 8 + Platform.LONG_ARRAY_OFFSET - while (offset < end) { - val size = Math.min(buffer.length, end - offset) - in.readFully(buffer, 0, size) - Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) - offset += size - } - - val numBytes = in.readInt() - page = new Array[Byte](numBytes) - in.readFully(page) + hashTable = SparkSqlSerializer.deserialize(readBytes(in)) } } -private[joins] class LongHashedRelation( - private var nFields: Int, - private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable { - - private var resultRow: UnsafeRow = new UnsafeRow(nFields) +/** + * A relation that pack all the rows into a byte array, together with offsets and sizes. + * + * All the bytes of UnsafeRow are packed together as `bytes`: + * + * [ Row0 ][ Row1 ][] ... [ RowN ] + * + * With keys: + * + * start start+1 ... start+N + * + * `offsets` are offsets of UnsafeRows in the `bytes` + * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. 
+ * + * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: + * + * start = 3 + * offsets = [0, 0, 24] + * sizes = [24, 0, 32] + * bytes = [0 - 24][][24 - 56] + */ +private[joins] final class LongArrayRelation( + private var numFields: Int, + private var start: Long, + private var offsets: Array[Int], + private var sizes: Array[Int], + private var bytes: Array[Byte] + ) extends LongHashedRelation with Externalizable { // Needed for serialization (it is public to make Java serialization work) - def this() = this(0, null) + def this() = this(0, 0L, null, null, null) - override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) + override def keyIsUnique: Boolean = true - override def estimatedSize: Long = { - map.getTotalMemoryConsumption + override def asReadOnlyCopy(): LongArrayRelation = { + new LongArrayRelation(numFields, start, offsets, sizes, bytes) } - override def get(key: InternalRow): Iterator[InternalRow] = { - if (key.isNullAt(0)) { - null - } else { - get(key.getLong(0)) - } + override def getMemorySize: Long = { + offsets.length * 4 + sizes.length * 4 + bytes.length } - override def getValue(key: InternalRow): InternalRow = { - if (key.isNullAt(0)) { - null + override def get(key: Long): Iterator[InternalRow] = { + val row = getValue(key) + if (row != null) { + Seq(row).toIterator } else { - getValue(key.getLong(0)) + null } } - override def get(key: Long): Iterator[InternalRow] = - map.get(key, resultRow) - + var resultRow = new UnsafeRow(numFields) override def getValue(key: Long): InternalRow = { - map.getValue(key, resultRow) - } - - override def keyIsUnique: Boolean = map.keyIsUnique - - override def close(): Unit = { - map.free() + val idx = (key - start).toInt + if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { + resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) + resultRow + } else { + null + } } override def writeExternal(out: ObjectOutput): Unit = { - out.writeInt(nFields) - out.writeObject(map) + out.writeInt(numFields) + out.writeLong(start) + out.writeInt(sizes.length) + var i = 0 + while (i < sizes.length) { + out.writeInt(sizes(i)) + i += 1 + } + out.writeInt(bytes.length) + out.write(bytes) } override def readExternal(in: ObjectInput): Unit = { - nFields = in.readInt() - resultRow = new UnsafeRow(nFields) - map = in.readObject().asInstanceOf[LongToUnsafeRowMap] + numFields = in.readInt() + resultRow = new UnsafeRow(numFields) + start = in.readLong() + val length = in.readInt() + // read sizes of rows + sizes = new Array[Int](length) + offsets = new Array[Int](length) + var i = 0 + var offset = 0 + while (i < length) { + offsets(i) = offset + sizes(i) = in.readInt() + offset += sizes(i) + i += 1 + } + // read all the bytes + val total = in.readInt() + assert(total == offset) + bytes = new Array[Byte](total) + in.readFully(bytes) } } @@ -763,45 +466,96 @@ private[joins] class LongHashedRelation( * Create hashed relation with key that is long. 
*/ private[joins] object LongHashedRelation { + + val DENSE_FACTOR = 0.2 + def apply( - input: Iterator[InternalRow], - key: Seq[Expression], - sizeEstimate: Int, - taskMemoryManager: TaskMemoryManager): LongHashedRelation = { + input: Iterator[InternalRow], + keyGenerator: Projection, + sizeEstimate: Int): HashedRelation = { - val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) - val keyGenerator = UnsafeProjection.create(key) + // TODO: use LongToBytesMap for better memory efficiency + val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate) // Create a mapping of key -> rows var numFields = 0 + var keyIsUnique = true + var minKey = Long.MaxValue + var maxKey = Long.MinValue while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] numFields = unsafeRow.numFields() val rowKey = keyGenerator(unsafeRow) - if (!rowKey.isNullAt(0)) { + if (!rowKey.anyNull) { val key = rowKey.getLong(0) - map.append(key, unsafeRow) + minKey = math.min(minKey, key) + maxKey = math.max(maxKey, key) + val existingMatchList = hashTable.get(key) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[UnsafeRow]() + hashTable.put(key, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += unsafeRow + } + } + + if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { + // The keys are dense enough, so use LongArrayRelation + val length = (maxKey - minKey).toInt + 1 + val sizes = new Array[Int](length) + val offsets = new Array[Int](length) + var offset = 0 + var i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + offsets(i) = offset + sizes(i) = rows(0).getSizeInBytes + offset += sizes(i) + } + i += 1 + } + val bytes = new Array[Byte](offset) + i = 0 + while (i < length) { + val rows = hashTable.get(i + minKey) + if (rows != null) { + rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) + } + i += 1 } + new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) + } else { + new GeneralLongHashedRelation(hashTable) } - map.optimize() - new LongHashedRelation(numFields, map) } } /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. 
*/ -private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) - extends BroadcastMode { +private[execution] case class HashedRelationBroadcastMode( + canJoinKeyFitWithinLong: Boolean, + keys: Seq[Expression], + attributes: Seq[Attribute]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalizedKey, rows.length) + val generator = UnsafeProjection.create(keys, attributes) + HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length) } - private lazy val canonicalizedKey: Seq[Expression] = { - key.map { e => e.canonicalized } + private lazy val canonicalizedKeys: Seq[Expression] = { + keys.map { e => + BindReferences.bindReference(e.canonicalized, attributes) + } } override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey + case m: HashedRelationBroadcastMode => + canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong && + canonicalizedKeys == m.canonicalizedKeys case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 0c3e3c3fc1..bf86096379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.TaskContext +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.memory.MemoryMode import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -56,20 +57,54 @@ case class ShuffledHashJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { + private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { val context = TaskContext.get() - val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) - // This relation is usually used until the end of task. + if (!canJoinKeyFitWithinLong) { + // build BytesToBytesMap + val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator) + // This relation is usually used until the end of task. + context.addTaskCompletionListener((t: TaskContext) => + relation.close() + ) + return relation + } + + // try to acquire some memory for the hash table, it could trigger other operator to free some + // memory. The memory acquired here will mostly be used until the end of task. + val memoryManager = context.taskMemoryManager() + var acquired = 0L + var used = 0L context.addTaskCompletionListener((t: TaskContext) => - relation.close() + memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null) ) - relation + + val copiedIter = iter.map { row => + // It's hard to guess what's exactly memory will be used, we have a rough guess here. 
+ // TODO: use LongToBytesMap instead of HashMap for memory efficiency + // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers + val needed = 150 + row.getSizeInBytes + if (needed > acquired - used) { + val got = memoryManager.acquireExecutionMemory( + Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null) + acquired += got + if (got < needed) { + throw new SparkException("Can't acquire enough memory to build hash map in shuffled" + + "hash join, please use sort merge join by setting " + + "spark.sql.join.preferSortMergeJoin=true") + } + } + used += needed + // HashedRelation requires that the UnsafeRow should be separate objects. + row.copy() + } + + HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator) } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => - val hashed = buildHashedRelation(buildIter) + val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) join(streamIter, hashed, numOutputRows) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 352fd07d0e..5dbf619876 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,7 +21,6 @@ import java.util.HashMap import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} -import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.AggregateHashMap @@ -180,8 +179,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X - Join w long codegen=true 321 / 371 65.3 15.3 9.3X + Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X + Join w long codegen=true 275 / 352 76.2 13.1 19.4X */ runBenchmark("Join w long duplicated", N) { @@ -194,8 +193,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X - Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X + Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X + Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X */ val dim2 = broadcast(sqlContext.range(M) @@ -212,8 +211,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X - Join w 2 ints codegen=true 791 / 818 26.5 37.7 5.6X + Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X + Join w 2 ints 
codegen=true 2565 / 2816 8.2 122.3 3.5X */ val dim3 = broadcast(sqlContext.range(M) @@ -260,8 +259,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X - outer join w long codegen=true 261 / 276 80.5 12.4 11.7X + outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X + outer join w long codegen=true 216 / 226 97.2 10.3 26.3X */ runBenchmark("semi join w long", N) { @@ -273,8 +272,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X - semi join w long codegen=true 237 / 244 88.3 11.3 8.1X + semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X + semi join w long codegen=true 211 / 229 99.2 10.1 22.2X */ } @@ -327,8 +326,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X - shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X + shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X + shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X */ } @@ -350,11 +349,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("hash and BytesToBytesMap") { - val N = 20 << 20 + val N = 10 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) - benchmark.addCase("UnsafeRowhash") { iter => + benchmark.addCase("hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) @@ -369,34 +368,15 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } - benchmark.addCase("murmur3 hash") { iter => - var i = 0 - val keyBytes = new Array[Byte](16) - val key = new UnsafeRow(1) - key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var p = 524283 - var s = 0 - while (i < N) { - var h = Murmur3_x86_32.hashLong(i, 42) - key.setInt(0, h) - s += h - i += 1 - } - } - benchmark.addCase("fast hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) - var p = 524283 var s = 0 while (i < N) { - var h = i % p - if (h < 0) { - h += p - } - key.setInt(0, h) + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashLong(i % 1000, 42) s += h i += 1 } @@ -495,42 +475,6 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } - Seq(false, true).foreach { optimized => - benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => - var i = 0 - val valueBytes = new Array[Byte](16) - val value = new UnsafeRow(1) - value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) - value.setInt(0, 555) - val taskMemoryManager = new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - val map = new LongToUnsafeRowMap(taskMemoryManager, 64) - while (i < 65536) { - value.setInt(0, i) - val key = i % 100000 - map.append(key, value) - i += 1 - } - if (optimized) { - map.optimize() - } 
- var s = 0 - i = 0 - while (i < N) { - val key = i % 100000 - if (map.getValue(key, value) != null) { - s += 1 - } - i += 1 - } - } - } - Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -549,27 +493,18 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - val numKeys = 65536 - while (i < numKeys) { + while (i < N) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, Murmur3_x86_32.hashLong(i % 65536, 42)) - if (!loc.isDefined) { + if (loc.isDefined) { + value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) + value.setInt(0, value.getInt(0) + 1) + i += 1 + } else { loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) } - i += 1 - } - i = 0 - var s = 0 - while (i < N) { - key.setInt(0, i % 100000) - val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, - Murmur3_x86_32.hashLong(i % 100000, 42)) - if (loc.isDefined) { - s += 1 - } - i += 1 } } } @@ -600,19 +535,16 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - UnsafeRow hash 267 / 284 78.4 12.8 1.0X - murmur3 hash 102 / 129 205.5 4.9 2.6X - fast hash 79 / 96 263.8 3.8 3.4X - arrayEqual 164 / 172 128.2 7.8 1.6X - Java HashMap (Long) 321 / 399 65.4 15.3 0.8X - Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X - Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X - LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X - LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X - BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X - BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X - Aggregate HashMap 121 / 131 173.3 5.8 2.2X - */ + hash 112 / 116 93.2 10.7 1.0X + fast hash 65 / 69 160.9 6.2 1.7X + arrayEqual 66 / 69 159.1 6.3 1.7X + Java HashMap (Long) 137 / 182 76.3 13.1 0.8X + Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X + Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X + BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X + BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X + Aggregate HashMap 56 / 62 187.9 5.3 2.0X + */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 17f2343cf9..9680f3a008 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode - val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) - val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) + val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq()) + val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq()) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) @@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val exchange1 = 
BroadcastExchange(IdentityBroadcastMode, plan) - val hashMode = HashedRelationBroadcastMode(output) + val hashMode = HashedRelationBroadcastMode(true, output, plan.output) val exchange2 = BroadcastExchange(hashMode, plan) val hashMode2 = - HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) + HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output) val exchange3 = BroadcastExchange(hashMode2, plan) val exchange4 = ReusedExchange(output, exchange3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 371a9ed617..ed87a99439 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,23 +30,15 @@ import org.apache.spark.util.collection.CompactBuffer class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { - val mm = new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()) - val buildKey = Seq(BoundReference(0, IntegerType, false)) - val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm) + val keyGenerator = UnsafeProjection.create(buildKey) + val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) @@ -108,45 +100,31 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } - test("LongToUnsafeRowMap") { + test("LongArrayRelation") { val unsafeProj = UnsafeProjection.create( Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) - val key = Seq(BoundReference(0, IntegerType, false)) - val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) - assert(longRelation.keyIsUnique) + val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false))) + val longRelation = LongHashedRelation(rows.iterator, keyProj, 100) + assert(longRelation.isInstanceOf[LongArrayRelation]) + val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation] (0 until 100).foreach { i => - val row = longRelation.getValue(i) + val row = longArrayRelation.getValue(i) assert(row.getInt(0) === i) assert(row.getInt(1) === i + 1) } - val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm) - assert(!longRelation2.keyIsUnique) - (0 until 100).foreach { i => - val rows = longRelation2.get(i).toArray - assert(rows.length === 2) - assert(rows(0).getInt(0) === i) - assert(rows(0).getInt(1) === i + 1) - assert(rows(1).getInt(0) === i) - assert(rows(1).getInt(1) === i + 1) - } - val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) - longRelation2.writeExternal(out) + longArrayRelation.writeExternal(out) out.flush() val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) - val relation = new LongHashedRelation() + 
val relation = new LongArrayRelation() relation.readExternal(in) - assert(!relation.keyIsUnique) (0 until 100).foreach { i => - val rows = relation.get(i).toArray - assert(rows.length === 2) - assert(rows(0).getInt(0) === i) - assert(rows(0).getInt(1) === i + 1) - assert(rows(1).getInt(0) === i) - assert(rows(1).getInt(1) === i + 1) + val row = longArrayRelation.getValue(i) + assert(row.getInt(0) === i) + assert(row.getInt(1) === i + 1) } } } -- cgit v1.2.3 From cd2fed70129ba601f8c849a93eeb44a5d69c2402 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Apr 2016 13:54:30 -0700 Subject: [SPARK-14335][SQL] Describe function command returns wrong output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? …because some of built-in functions are not in function registry. This fix tries to fix issues in `describe function` command where some of the outputs still shows Hive's function because some built-in functions are not in FunctionRegistry. The following built-in functions have been added to FunctionRegistry: ``` - ! * / & % ^ + < <= <=> = == > >= | ~ and in like not or rlike when ``` The following listed functions are not added, but hard coded in `commands.scala` (hvanhovell): ``` != <> between case ``` Below are the existing result of the above functions that have not been added: ``` spark-sql> describe function `!=`; Function: <> Class: org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual Usage: a <> b - Returns TRUE if a is not equal to b ``` ``` spark-sql> describe function `<>`; Function: <> Class: org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual Usage: a <> b - Returns TRUE if a is not equal to b ``` ``` spark-sql> describe function `between`; Function: between Class: org.apache.hadoop.hive.ql.udf.generic.GenericUDFBetween Usage: between a [NOT] BETWEEN b AND c - evaluate if a is [not] in between b and c ``` ``` spark-sql> describe function `case`; Function: case Class: org.apache.hadoop.hive.ql.udf.generic.GenericUDFCase Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - When a = b, returns c; when a = d, return e; else return f ``` ## How was this patch tested? Existing tests passed. Additional test cases added. Author: Yong Tang Closes #12128 from yongtang/SPARK-14335. 
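For illustration only (not part of this patch): a minimal sketch of how the newly registered operators could be probed from an existing `SQLContext` named `sqlContext`, with expected strings mirroring the `SQLQuerySuite` expectations added below.

```scala
// Assumes an existing SQLContext `sqlContext`; after this patch `~` resolves to
// Spark's own BitwiseNot instead of Hive's UDFOPBitNot.
val rows = sqlContext.sql("DESCRIBE FUNCTION `~`").collect()
rows.foreach(println)
// Expected to include:
//   Function: ~
//   Class: org.apache.spark.sql.catalyst.expressions.BitwiseNot
//   Usage: To be added.

// `<>`, `!=`, `between` and `case` are answered from the hard-coded branch in
// DescribeFunction rather than from the FunctionRegistry:
sqlContext.sql("DESCRIBE FUNCTION `between`").collect().foreach(println)
// Expected to include:
//   Function: between
//   Usage: a [NOT] BETWEEN b AND c - evaluate if a is [not] in between b and c
```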
--- .../sql/catalyst/analysis/FunctionRegistry.scala | 33 +++++++++++++++- .../spark/sql/execution/command/commands.scala | 44 +++++++++++++++------- .../spark/sql/hive/execution/SQLQuerySuite.scala | 30 +++++++++++---- 3 files changed, 86 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f239b33e44..f2abf136da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -171,6 +171,7 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), + expression[CaseWhen]("when"), // math functions expression[Acos]("acos"), @@ -217,6 +218,12 @@ object FunctionRegistry { expression[Tan]("tan"), expression[Tanh]("tanh"), + expression[Add]("+"), + expression[Subtract]("-"), + expression[Multiply]("*"), + expression[Divide]("/"), + expression[Remainder]("%"), + // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), @@ -257,6 +264,7 @@ object FunctionRegistry { expression[Lower]("lcase"), expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[Like]("like"), expression[Lower]("lower"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), @@ -267,6 +275,7 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReverse]("reverse"), + expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), expression[SoundEx]("soundex"), @@ -343,7 +352,29 @@ object FunctionRegistry { expression[NTile]("ntile"), expression[Rank]("rank"), expression[DenseRank]("dense_rank"), - expression[PercentRank]("percent_rank") + expression[PercentRank]("percent_rank"), + + // predicates + expression[And]("and"), + expression[In]("in"), + expression[Not]("not"), + expression[Or]("or"), + + expression[EqualNullSafe]("<=>"), + expression[EqualTo]("="), + expression[EqualTo]("=="), + expression[GreaterThan](">"), + expression[GreaterThanOrEqual](">="), + expression[LessThan]("<"), + expression[LessThanOrEqual]("<="), + expression[Not]("!"), + + // bitwise + expression[BitwiseAnd]("&"), + expression[BitwiseNot]("~"), + expression[BitwiseOr]("|"), + expression[BitwiseXor]("^") + ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 3fd2a93d29..5d00c805a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -483,20 +483,38 @@ case class DescribeFunction( } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { - case Some(info) => - val result = - Row(s"Function: ${info.getName}") :: - Row(s"Class: ${info.getClassName}") :: - Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil - - if (isExtended) { - result :+ Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") - } else { - result - } + // Hard code "<>", "!=", "between", and "case" for now as 
there is no corresponding functions. + functionName.toLowerCase match { + case "<>" => + Row(s"Function: $functionName") :: + Row(s"Usage: a <> b - Returns TRUE if a is not equal to b") :: Nil + case "!=" => + Row(s"Function: $functionName") :: + Row(s"Usage: a != b - Returns TRUE if a is not equal to b") :: Nil + case "between" => + Row(s"Function: between") :: + Row(s"Usage: a [NOT] BETWEEN b AND c - " + + s"evaluate if a is [not] in between b and c") :: Nil + case "case" => + Row(s"Function: case") :: + Row(s"Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + + s"When a = b, returns c; when a = d, return e; else return f") :: Nil + case _ => sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { + case Some(info) => + val result = + Row(s"Function: ${info.getName}") :: + Row(s"Class: ${info.getClassName}") :: + Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil + + if (isExtended) { + result :+ + Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") + } else { + result + } - case None => Seq(Row(s"Function: $functionName not found.")) + case None => Seq(Row(s"Function: $functionName not found.")) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 14a1d4cd30..d7ec85c15d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -203,8 +203,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) - // TODO: Re-enable this test after we fix SPARK-14335. - // checkAnswer(sql("SHOW functions `~`"), Row("~")) + checkAnswer(sql("SHOW functions `~`"), Row("~")) checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear")) // this probably will failed if we add more function with `sha` prefixing. @@ -236,11 +235,28 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkExistence(sql("describe functioN abcadf"), true, "Function: abcadf not found.") - // TODO: Re-enable this test after we fix SPARK-14335. 
- // checkExistence(sql("describe functioN `~`"), true, - // "Function: ~", - // "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", - // "Usage: ~ n - Bitwise not") + checkExistence(sql("describe functioN `~`"), true, + "Function: ~", + "Class: org.apache.spark.sql.catalyst.expressions.BitwiseNot", + "Usage: To be added.") + + // Hard coded describe functions + checkExistence(sql("describe function `<>`"), true, + "Function: <>", + "Usage: a <> b - Returns TRUE if a is not equal to b") + + checkExistence(sql("describe function `!=`"), true, + "Function: !=", + "Usage: a != b - Returns TRUE if a is not equal to b") + + checkExistence(sql("describe function `between`"), true, + "Function: between", + "Usage: a [NOT] BETWEEN b AND c - evaluate if a is [not] in between b and c") + + checkExistence(sql("describe function `case`"), true, + "Function: case", + "Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + + "When a = b, returns c; when a = d, return e; else return f") } test("SPARK-5371: union with null and sum") { -- cgit v1.2.3 From 415446cc9b2652f6da11ee8ead5eb4e66685c45f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sat, 9 Apr 2016 14:03:03 -0700 Subject: Revert "[SPARK-14462][ML][MLLIB] add the mllib-local build to maven pom" This reverts commit 1598d11bb0248384872cf88bc2b16f3b238046ad. --- dev/sparktestsupport/modules.py | 14 +--- mllib-local/pom.xml | 94 ---------------------- .../scala/org/apache/spark/ml/DummyTesting.scala | 23 ------ .../org/apache/spark/ml/DummyTestingSuite.scala | 28 ------- mllib/pom.xml | 12 --- pom.xml | 1 - project/SparkBuild.scala | 6 +- 7 files changed, 4 insertions(+), 174 deletions(-) delete mode 100644 mllib-local/pom.xml delete mode 100644 mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala delete mode 100644 mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index c844bcff7e..bb04ec6ee6 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -256,21 +256,9 @@ streaming_flume_assembly = Module( ) -mllib_local = Module( - name="mllib-local", - dependencies=[], - source_file_regexes=[ - "mllib-local", - ], - sbt_test_goals=[ - "mllib-local/test", - ] -) - - mllib = Module( name="mllib", - dependencies=[mllib_local, streaming, sql], + dependencies=[streaming, sql], source_file_regexes=[ "data/mllib/", "mllib/", diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml deleted file mode 100644 index 69917eb0fb..0000000000 --- a/mllib-local/pom.xml +++ /dev/null @@ -1,94 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.11 - 2.0.0-SNAPSHOT - ../pom.xml - - - org.apache.spark - spark-mllib-local_2.11 - - mllib-local - - jar - Spark Project ML Local Library - http://spark.apache.org/ - - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.scalanlp - breeze_${scala.binary.version} - 0.11.2 - - - - junit - junit - - - org.apache.commons - commons-math3 - - - - - org.apache.commons - commons-math3 - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.mockito - mockito-core - test - - - - - netlib-lgpl - - - com.github.fommil.netlib - all - ${netlib.java.version} - pom - - - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala 
b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala deleted file mode 100644 index 6b3268cdfa..0000000000 --- a/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml - -// This is a private class testing if the new build works. To be removed soon. -private[ml] object DummyTesting { - private[ml] def add10(input: Double): Double = input + 10 -} diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala deleted file mode 100644 index 6c76dbfbfa..0000000000 --- a/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ml - -import org.apache.spark.SparkFunSuite - -// This is testing if the new build works. To be removed soon. 
-class DummyTestingSuite extends SparkFunSuite { - - test("This is testing if the new build works.") { - assert(DummyTesting.add10(15) === 25) - } -} diff --git a/mllib/pom.xml b/mllib/pom.xml index e56eafc300..428176dcbf 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -62,18 +62,6 @@ spark-graphx_${scala.binary.version} ${project.version} - - org.apache.spark - spark-mllib-local_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-mllib-local_${scala.binary.version} - ${project.version} - test-jar - test - org.scalanlp breeze_${scala.binary.version} diff --git a/pom.xml b/pom.xml index 58c85b1d36..3f9e4abc32 100644 --- a/pom.xml +++ b/pom.xml @@ -94,7 +94,6 @@ core graphx mllib - mllib-local tools streaming sql/catalyst diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c5688ecec6..60124ef0a1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -47,9 +47,9 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( - core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* + core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* ) = Seq( - "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", + "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe", "test-tags", "sketch" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects @@ -254,7 +254,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, testTags, sketch, mllibLocal + unsafe, testTags, sketch ).contains(x) } -- cgit v1.2.3 From 9be5558e009069925d1f2d737d42e1683ed6b47f Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 9 Apr 2016 14:10:44 -0700 Subject: [SPARK-14481][SQL] Issue Exceptions for All Unsupported Options during Parsing #### What changes were proposed in this pull request? "Not good to slightly ignore all the un-supported options/clauses. We should either support it or throw an exception." A comment from yhuai in another PR https://github.com/apache/spark/pull/12146 - Can `Explain` be an exception? The `Formatted` clause is used in `HiveCompatibilitySuite`. - Two unsupported clauses in `Drop Table` are handled in a separate PR: https://github.com/apache/spark/pull/12146 #### How was this patch tested? Test cases are added to verify all the cases. Author: gatorsmile Closes #12255 from gatorsmile/warningToException. 
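For illustration only (not part of this patch): a sketch of how one of the newly rejected clauses surfaces to callers, written as a test-body fragment in the style of `DDLCommandSuite`; `parser`, `intercept`, and the `ParseException` import are assumed to be those already used by that suite.

```scala
// Fragment in the style of DDLCommandSuite; `parser` and `intercept` come from
// that suite's existing setup and imports.
val e = intercept[ParseException] {
  parser.parsePlan(
    """
      |CREATE EXTERNAL TABLE oneToTenDef
      |USING org.apache.spark.sql.sources
      |OPTIONS (from '1', to '10')
    """.stripMargin)
}
// Before this patch the EXTERNAL keyword was only logged and ignored; now the
// error message names the unsupported option explicitly.
assert(e.getMessage.contains("Unsupported operation: EXTERNAL option"))
```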
--- .../spark/sql/execution/SparkSqlParser.scala | 7 +-- .../sql/execution/command/DDLCommandSuite.scala | 10 +++- .../spark/sql/hive/execution/HiveSqlParser.scala | 16 +++--- .../spark/sql/hive/HiveDDLCommandSuite.scala | 60 +++++++++++++++++++++- .../sql/hive/execution/HiveCommandSuite.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 1 - 6 files changed, 81 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index bf21c9d524..c8d0f4e3c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -143,10 +143,7 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { val options = ctx.explainOption.asScala if (options.exists(_.FORMATTED != null)) { - logWarning("EXPLAIN FORMATTED option is ignored.") - } - if (options.exists(_.LOGICAL != null)) { - logWarning("EXPLAIN LOGICAL option is ignored.") + logWarning("Unsupported operation: EXPLAIN FORMATTED option") } // Create the explain comment. @@ -206,7 +203,7 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) if (external) { - logWarning("EXTERNAL option is not supported.") + throw new ParseException("Unsupported operation: EXTERNAL option", ctx) } val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 8e63b69876..b1c1fd0951 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -665,7 +665,7 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } - test("commands only available in HiveContext") { + test("unsupported operations") { intercept[ParseException] { parser.parsePlan("DROP TABLE D1.T1") } @@ -682,6 +682,14 @@ class DDLCommandSuite extends PlanTest { |TBLPROPERTIES('prop1Key '= "prop1Val", ' `prop2Key` '= "prop2Val") """.stripMargin) } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE EXTERNAL TABLE oneToTenDef + |USING org.apache.spark.sql.sources + |OPTIONS (from '1', to '10') + """.stripMargin) + } intercept[ParseException] { parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index ab69d3502e..657edb493a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -162,14 +162,16 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { // Unsupported clauses. 
if (temp) { - logWarning("TEMPORARY clause is ignored.") + throw new ParseException(s"Unsupported operation: TEMPORARY clause.", ctx) } if (ctx.bucketSpec != null) { // TODO add this - we need cluster columns in the CatalogTable for this to work. - logWarning("CLUSTERED BY ... [ORDERED BY ...] INTO ... BUCKETS clause is ignored.") + throw new ParseException("Unsupported operation: " + + "CLUSTERED BY ... [ORDERED BY ...] INTO ... BUCKETS clause.", ctx) } if (ctx.skewSpec != null) { - logWarning("SKEWED BY ... ON ... [STORED AS DIRECTORIES] clause is ignored.") + throw new ParseException("Operation not allowed: " + + "SKEWED BY ... ON ... [STORED AS DIRECTORIES] clause.", ctx) } // Create the schema. @@ -230,7 +232,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { throw new ParseException(s"Operation not allowed: partitioned views", ctx) } else { if (ctx.STRING != null) { - logWarning("COMMENT clause is ignored.") + throw new ParseException("Unsupported operation: COMMENT clause", ctx) } val identifiers = Option(ctx.identifierCommentList).toSeq.flatMap(_.identifierComment.asScala) val schema = identifiers.map { ic => @@ -296,7 +298,8 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { recordReader: Token, schemaLess: Boolean): HiveScriptIOSchema = { if (recordWriter != null || recordReader != null) { - logWarning("Used defined record reader/writer classes are currently ignored.") + throw new ParseException( + "Unsupported operation: Used defined record reader/writer classes.", ctx) } // Decode and input/output format. @@ -370,7 +373,8 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { import ctx._ if (inDriver != null || outDriver != null) { - logWarning("INPUTDRIVER ... OUTPUTDRIVER ... clauses are ignored.") + throw new ParseException( + s"Operation not allowed: INPUTDRIVER ... OUTPUTDRIVER ... 
clauses", ctx) } EmptyStorageFormat.copy( inputFormat = Option(string(inFmt)), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index c5f01da4fa..12a582c10a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -180,6 +180,65 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.properties == Map(("tbl_p1" -> "p11"), ("tbl_p2" -> "p22"))) } + test("unsupported operations") { + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TEMPORARY TABLE ctas2 + |ROW FORMAT SERDE "org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe" + |WITH SERDEPROPERTIES("serde_p1"="p1","serde_p2"="p2") + |STORED AS RCFile + |TBLPROPERTIES("tbl_p1"="p11", "tbl_p2"="p22") + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """CREATE TABLE ctas2 + |STORED AS + |INPUTFORMAT "org.apache.hadoop.mapred.TextInputFormat" + |OUTPUTFORMAT "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat" + |INPUTDRIVER "org.apache.hadoop.hive.howl.rcfile.RCFileInputDriver" + |OUTPUTDRIVER "org.apache.hadoop.hive.howl.rcfile.RCFileOutputDriver" + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |CLUSTERED BY(user_id) INTO 256 BUCKETS + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE TABLE user_info_bucketed(user_id BIGINT, firstname STRING, lastname STRING) + |SKEWED BY (key) ON (1,5,6) + |AS SELECT key, value FROM src ORDER BY key, value + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.contrib.serde2.TypedBytesSerDe' + |RECORDREADER 'org.apache.hadoop.hive.contrib.util.typedbytes.TypedBytesRecordReader' + |FROM testData + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE OR REPLACE VIEW IF NOT EXISTS view1 (col1, col3) + |COMMENT 'blabla' + |TBLPROPERTIES('prop1Key'="prop1Val") + |AS SELECT * FROM tab1 + """.stripMargin) + } + } + test("Invalid interval term should throw AnalysisException") { def assertError(sql: String, errorMessage: String): Unit = { val e = intercept[AnalysisException] { @@ -277,7 +336,6 @@ class HiveDDLCommandSuite extends PlanTest { """ |CREATE OR REPLACE VIEW IF NOT EXISTS view1 |(col1, col3) - |COMMENT 'I cannot spell' |TBLPROPERTIES('prop1Key'="prop1Val") |AS SELECT * FROM tab1 """.stripMargin diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index 4c3f450522..8de2bdcfc0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -26,7 +26,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto super.beforeAll() sql( """ - |CREATE EXTERNAL TABLE parquet_tab1 (c1 INT, c2 STRING) + |CREATE TABLE parquet_tab1 (c1 INT, c2 STRING) |USING 
org.apache.spark.sql.parquet.DefaultSource """.stripMargin) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index d7ec85c15d..f3796a9966 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1491,7 +1491,6 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { sql( """CREATE VIEW IF NOT EXISTS |default.testView (c1 COMMENT 'blabla', c2 COMMENT 'blabla') - |COMMENT 'blabla' |TBLPROPERTIES ('a' = 'b') |AS SELECT * FROM jt""".stripMargin) checkAnswer(sql("SELECT c1, c2 FROM testView ORDER BY c1"), (1 to 9).map(i => Row(i, i))) -- cgit v1.2.3 From dfce9665c4b2b29a19e6302216dae2800da68ff9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 9 Apr 2016 17:40:36 -0700 Subject: [SPARK-14362][SPARK-14406][SQL] DDL Native Support: Drop View and Drop Table #### What changes were proposed in this pull request? This PR is to provide a native support for DDL `DROP VIEW` and `DROP TABLE`. The PR includes native parsing and native analysis. Based on the HIVE DDL document for [DROP_VIEW_WEB_LINK](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL- DropView ), `DROP VIEW` is defined as, **Syntax:** ```SQL DROP VIEW [IF EXISTS] [db_name.]view_name; ``` - to remove metadata for the specified view. - illegal to use DROP TABLE on a view. - illegal to use DROP VIEW on a table. - this command only works in `HiveContext`. In `SQLContext`, we will get an exception. This PR also handles `DROP TABLE`. **Syntax:** ```SQL DROP TABLE [IF EXISTS] table_name [PURGE]; ``` - Previously, the `DROP TABLE` command only can drop Hive tables in `HiveContext`. Now, after this PR, this command also can drop temporary table, external table, external data source table in `SQLContext`. - In `HiveContext`, we will not issue an exception if the to-be-dropped table does not exist and users did not specify `IF EXISTS`. Instead, we just log an error message. If `IF EXISTS` is specified, we will not issue any error message/exception. - In `SQLContext`, we will issue an exception if the to-be-dropped table does not exist, unless `IF EXISTS` is specified. - Data will not be deleted if the tables are `external`, unless table type is `managed_table`. #### How was this patch tested? For verifying command parsing, added test cases in `spark/sql/hive/HiveDDLCommandSuite.scala` For verifying command analysis, added test cases in `spark/sql/hive/execution/HiveDDLSuite.scala` Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #12146 from gatorsmile/dropView. 
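For illustration only (not part of this patch): a minimal sketch of the behavior described above, assuming a HiveContext-backed `sqlContext` in which a managed table `tab1` and a view `view1` already exist (both names are hypothetical).

```scala
// Drops the metadata and, because `tab1` is a managed table, its data as well.
sqlContext.sql("DROP TABLE tab1")

// No error is raised for a missing table when IF EXISTS is given.
sqlContext.sql("DROP TABLE IF EXISTS tab_not_there")

// Views must be dropped with DROP VIEW; running DROP TABLE against `view1`
// (or DROP VIEW against a table) fails with an AnalysisException in the new
// DropTable command.
sqlContext.sql("DROP VIEW view1")
```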
--- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/catalog/InMemoryCatalog.scala | 4 + .../sql/catalyst/catalog/SessionCatalog.scala | 29 +++- .../spark/sql/catalyst/catalog/interface.scala | 2 + .../sql/catalyst/catalog/SessionCatalogSuite.scala | 7 +- .../spark/sql/execution/SparkSqlParser.scala | 16 +++ .../apache/spark/sql/execution/command/ddl.scala | 54 ++++++- .../sql/execution/command/DDLCommandSuite.scala | 56 +++++++- .../spark/sql/execution/command/DDLSuite.scala | 48 +++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 6 +- .../spark/sql/hive/HiveExternalCatalog.scala | 4 + .../apache/spark/sql/hive/HiveSessionCatalog.scala | 2 + .../spark/sql/hive/execution/HiveSqlParser.scala | 13 -- .../apache/spark/sql/hive/execution/commands.scala | 30 ---- .../spark/sql/hive/HiveDDLCommandSuite.scala | 10 +- .../spark/sql/hive/execution/HiveDDLSuite.scala | 156 +++++++++++++++++++++ 16 files changed, 376 insertions(+), 63 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 85cb585919..2f2e060b38 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -104,6 +104,7 @@ statement REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier identifierCommentList? (COMMENT STRING)? (PARTITIONED ON identifierList)? @@ -141,7 +142,6 @@ hiveNativeCommands | DELETE FROM tableIdentifier (WHERE booleanExpression)? | TRUNCATE TABLE tableIdentifier partitionSpec? (COLUMNS identifierList)? - | DROP VIEW (IF EXISTS)? qualifiedName | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? | START TRANSACTION (transactionMode (',' transactionMode)*)? | COMMIT WORK? 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 186bbccef1..1994acd1ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -187,6 +187,10 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).tables(table).table } + override def getTableOption(db: String, table: String): Option[CatalogTable] = synchronized { + if (!tableExists(db, table)) None else Option(catalog(db).tables(table).table) + } + override def tableExists(db: String, table: String): Boolean = synchronized { requireDbExists(db) catalog(db).tables.contains(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 7db9fd0527..c1e5a485e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -21,6 +21,7 @@ import java.io.File import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} @@ -41,7 +42,7 @@ class SessionCatalog( externalCatalog: ExternalCatalog, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, - conf: CatalystConf) { + conf: CatalystConf) extends Logging { import ExternalCatalog._ def this( @@ -175,6 +176,17 @@ class SessionCatalog( externalCatalog.getTable(db, table) } + /** + * Retrieve the metadata of an existing metastore table. + * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then return None if it doesn't exist. + */ + def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + externalCatalog.getTableOption(db, table) + } + // ------------------------------------------------------------- // | Methods that interact with temporary and metastore tables | // ------------------------------------------------------------- @@ -229,7 +241,13 @@ class SessionCatalog( val db = name.database.getOrElse(currentDb) val table = formatTableName(name.table) if (name.database.isDefined || !tempTables.contains(table)) { - externalCatalog.dropTable(db, table, ignoreIfNotExists) + // When ignoreIfNotExists is false, no exception is issued when the table does not exist. + // Instead, log it as an error message. This is consistent with Hive. + if (externalCatalog.tableExists(db, table)) { + externalCatalog.dropTable(db, table, ignoreIfNotExists = true) + } else if (!ignoreIfNotExists) { + logError(s"Table '${name.quotedString}' does not exist") + } } else { tempTables.remove(table) } @@ -283,9 +301,14 @@ class SessionCatalog( * explicitly specified. 
*/ def isTemporaryTable(name: TableIdentifier): Boolean = { - !name.database.isDefined && tempTables.contains(formatTableName(name.table)) + name.database.isEmpty && tempTables.contains(formatTableName(name.table)) } + /** + * Return whether View is supported + */ + def isViewSupported: Boolean = false + /** * List all tables in the specified database, including temporary tables. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index e29d6bd8b0..4ef59316ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -91,6 +91,8 @@ abstract class ExternalCatalog { def getTable(db: String, table: String): CatalogTable + def getTableOption(db: String, table: String): Option[CatalogTable] + def tableExists(db: String, table: String): Boolean def listTables(db: String): Seq[String] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 1850dc8156..862fc275ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -233,10 +233,9 @@ class SessionCatalogSuite extends SparkFunSuite { intercept[AnalysisException] { catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true) } - // Table does not exist - intercept[AnalysisException] { - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) - } + // If the table does not exist, we do not issue an exception. Instead, we output an error log + // message to console when ignoreIfNotExists is set to false. + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index c8d0f4e3c5..3da715cdb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -363,6 +363,22 @@ class SparkSqlAstBuilder extends AstBuilder { } } + /** + * Create a [[DropTable]] command. + */ + override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { + if (ctx.PURGE != null) { + throw new ParseException("Unsupported operation: PURGE option", ctx) + } + if (ctx.REPLICATION != null) { + throw new ParseException("Unsupported operation: REPLICATION clause", ctx) + } + DropTable( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXISTS != null, + ctx.VIEW != null) + } + /** * Create a [[AlterTableRename]] command. 
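The SessionCatalog.dropTable change earlier in this patch also relaxes metastore drops: when the table does not exist and ignoreIfNotExists is false, an error is logged instead of an AnalysisException being thrown. A rough sketch of the observable behavior, assuming `catalog` is a SessionCatalog backed by an InMemoryCatalog as in SessionCatalogSuite:

import org.apache.spark.sql.catalyst.TableIdentifier

// db2 exists but unknown_table does not; before this patch the first call threw.
catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false)
// now: logs an error ("Table '`db2`.`unknown_table`' does not exist") and returns normally
catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true)
// unchanged: a silent no-op; only an unknown database still raises AnalysisException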
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 20779d68e0..e941736f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ @@ -175,13 +175,61 @@ case class DescribeDatabase( } } +/** + * Drops a table/view from the metastore and removes it if it is cached. + * + * The syntax of this command is: + * {{{ + * DROP TABLE [IF EXISTS] table_name; + * DROP VIEW [IF EXISTS] [db_name.]view_name; + * }}} + */ +case class DropTable( + tableName: TableIdentifier, + ifExists: Boolean, + isView: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (isView && !catalog.isViewSupported) { + throw new AnalysisException(s"Not supported object: views") + } + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadataOption(tableName).map(_.tableType match { + case CatalogTableType.VIRTUAL_VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIRTUAL_VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") + case _ => + }) + + try { + sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString)) + } catch { + // This table's metadata is not in Hive metastore (e.g. the table does not exist). + case e if e.getClass.getName == "org.apache.hadoop.hive.ql.metadata.InvalidTableException" => + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => + // Other Throwables can be caused by users providing wrong parameters in OPTIONS + // (e.g. invalid paths). We catch it and log a warning message. + // Users should be able to drop such kinds of tables regardless if there is an error. + case e: Throwable => log.warn(s"${e.getMessage}", e) + } + catalog.invalidateTable(tableName) + catalog.dropTable(tableName, ifExists) + Seq.empty[Row] + } +} + /** * A command that renames a table/view. 
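The DropTable command above refuses to mix object kinds; the messages it raises are exercised by the HiveDDLSuite added later in this patch. A minimal illustration, assuming a test body where tab1 was created with CREATE TABLE and views are supported:

val message = intercept[AnalysisException] {
  sql("DROP VIEW tab1")   // tab1 is a table, not a view
}.getMessage
assert(message.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead"))
// Symmetrically, DROP TABLE on a view fails with
// "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead", and on a
// session without view support any DROP VIEW fails with "Not supported object: views".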
* * The syntax of this command is: * {{{ - * ALTER TABLE table1 RENAME TO table2; - * ALTER VIEW view1 RENAME TO view2; + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; * }}} */ case class AlterTableRename( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index b1c1fd0951..ac69518ddf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.types._ class DDLCommandSuite extends PlanTest { @@ -667,7 +666,10 @@ class DDLCommandSuite extends PlanTest { test("unsupported operations") { intercept[ParseException] { - parser.parsePlan("DROP TABLE D1.T1") + parser.parsePlan("DROP TABLE tab PURGE") + } + intercept[ParseException] { + parser.parsePlan("DROP TABLE tab FOR REPLICATION('eventid')") } intercept[ParseException] { parser.parsePlan("CREATE VIEW testView AS SELECT id FROM tab") @@ -700,4 +702,52 @@ class DDLCommandSuite extends PlanTest { val parsed = parser.parsePlan(sql) assert(parsed.isInstanceOf[Project]) } + + test("drop table") { + val tableName1 = "db.tab" + val tableName2 = "tab" + + val parsed1 = parser.parsePlan(s"DROP TABLE $tableName1") + val parsed2 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName1") + val parsed3 = parser.parsePlan(s"DROP TABLE $tableName2") + val parsed4 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName2") + + val expected1 = + DropTable(TableIdentifier("tab", Option("db")), ifExists = false, isView = false) + val expected2 = + DropTable(TableIdentifier("tab", Option("db")), ifExists = true, isView = false) + val expected3 = + DropTable(TableIdentifier("tab", None), ifExists = false, isView = false) + val expected4 = + DropTable(TableIdentifier("tab", None), ifExists = true, isView = false) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } + + test("drop view") { + val viewName1 = "db.view" + val viewName2 = "view" + + val parsed1 = parser.parsePlan(s"DROP VIEW $viewName1") + val parsed2 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName1") + val parsed3 = parser.parsePlan(s"DROP VIEW $viewName2") + val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2") + + val expected1 = + DropTable(TableIdentifier("view", Option("db")), ifExists = false, isView = true) + val expected2 = + DropTable(TableIdentifier("view", Option("db")), ifExists = true, isView = true) + val expected3 = + DropTable(TableIdentifier("view", None), ifExists = false, isView = true) + val expected4 = + DropTable(TableIdentifier("view", None), ifExists = true, isView = true) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 7084665b3b..e75e5f5cb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -391,6 +391,54 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Nil) } + test("drop table - temporary table") { + val catalog = sqlContext.sessionState.catalog + sql( + """ + |CREATE TEMPORARY TABLE tab1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"))) + sql("DROP TABLE tab1") + assert(catalog.listTables("default") == Nil) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + + private def testDropTable(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.listTables("dbx") == Seq(tableIdent)) + sql("DROP TABLE dbx.tab1") + assert(catalog.listTables("dbx") == Nil) + sql("DROP TABLE IF EXISTS dbx.tab1") + // no exception will be thrown + sql("DROP TABLE dbx.tab1") + } + + test("drop view") { + val e = intercept[AnalysisException] { + sql("DROP VIEW dbx.tab1") + } + assert(e.getMessage.contains("Not supported object: views")) + } + private def convertToDatasourceTable( catalog: SessionCatalog, tableIdent: TableIdentifier): Unit = { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index e93b0c145f..9ec8b9a9a6 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -172,7 +172,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "OK" + -> "" ) } @@ -220,9 +220,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "SELECT count(key) FROM t1;" -> "5", "DROP TABLE t1;" - -> "OK", + -> "", "DROP TABLE sourceTable;" - -> "OK" + -> "" ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b1156fb3e2..a49ce33ba1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -182,6 +182,10 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.getTable(db, table) } + override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient { + client.getTableOption(db, table) + } + override def tableExists(db: String, table: String): Boolean = withClient { client.getTableOption(db, table).isDefined } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 0cccc22e5a..875652c226 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -70,6 +70,8 @@ private[sql] class HiveSessionCatalog( } } + override def isViewSupported: Boolean = true + // ---------------------------------------------------------------- // | Methods and fields for interacting with HiveMetastoreCatalog | // ---------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 657edb493a..7a435117e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -103,19 +103,6 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } } - /** - * Create a [[DropTable]] command. - */ - override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { - if (ctx.PURGE != null) { - logWarning("PURGE option is ignored.") - } - if (ctx.REPLICATION != null) { - logWarning("REPLICATION clause is ignored.") - } - DropTable(visitTableIdentifier(ctx.tableIdentifier).toString, ctx.EXISTS != null) - } - /** * Create an [[AnalyzeTable]] command. This currently only implements the NOSCAN option (other * options are passed on to Hive) e.g.: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 64d1341a47..06badff474 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -46,36 +46,6 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand { } } -/** - * Drops a table from the metastore and removes it if it is cached. - */ -private[hive] -case class DropTable( - tableName: String, - ifExists: Boolean) extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - val ifExistsClause = if (ifExists) "IF EXISTS " else "" - try { - hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) - } catch { - // This table's metadata is not in Hive metastore (e.g. the table does not exist). - case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => - case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => - // Other Throwables can be caused by users providing wrong parameters in OPTIONS - // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an error. 
- case e: Throwable => log.warn(s"${e.getMessage}", e) - } - hiveContext.invalidateTable(tableName) - hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") - hiveContext.sessionState.catalog.dropTable( - TableIdentifier(tableName), ignoreIfNotExists = true) - Seq.empty[Row] - } -} - private[hive] case class AddJar(path: String) extends RunnableCommand { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 12a582c10a..a144da4997 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -72,6 +72,7 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) // TODO will be SQLText assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) @@ -118,6 +119,7 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) // TODO will be SQLText assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) @@ -138,6 +140,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.locationUri == None) assert(desc.schema == Seq.empty[CatalogColumn]) assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewOriginalText.isEmpty) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) assert(desc.storage.outputFormat == @@ -173,6 +176,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.locationUri == None) assert(desc.schema == Seq.empty[CatalogColumn]) assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewOriginalText.isEmpty) assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) @@ -286,7 +290,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use backticks in output of Script Transform") { - val plan = parser.parsePlan( + parser.parsePlan( """SELECT `t`.`thing1` |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t @@ -294,7 +298,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use backticks in output of Generator") { - val plan = parser.parsePlan( + parser.parsePlan( """ |SELECT `gentab2`.`gencol2` |FROM `default`.`src` @@ -304,7 +308,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use escaped backticks in output of Generator") { - val plan = parser.parsePlan( + parser.parsePlan( """ |SELECT `gen``tab2`.`gen``col2` |FROM `default`.`src` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala new file mode 
100644 index 0000000000..78ccdc7adb --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} +import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ + + // check if the directory for recording the data of the table exists. + private def tableDirectoryExists(tableIdentifier: TableIdentifier): Boolean = { + val expectedTablePath = + hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier) + val filesystemPath = new Path(expectedTablePath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + fs.exists(filesystemPath) + } + + test("drop tables") { + withTable("tab1") { + val tabName = "tab1" + + assert(!tableDirectoryExists(TableIdentifier(tabName))) + sql(s"CREATE TABLE $tabName(c1 int)") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP TABLE $tabName") + + assert(!tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP TABLE IF EXISTS $tabName") + sql(s"DROP VIEW IF EXISTS $tabName") + } + } + + test("drop managed tables") { + withTempDir { tmpDir => + val tabName = "tab1" + withTable(tabName) { + assert(tmpDir.listFiles.isEmpty) + sql( + s""" + |create table $tabName + |stored as parquet + |location '$tmpDir' + |as select 1, '3' + """.stripMargin) + + val hiveTable = + hiveContext.sessionState.catalog + .getTableMetadata(TableIdentifier(tabName, Some("default"))) + // It is a managed table, although it uses external in SQL + assert(hiveTable.tableType == CatalogTableType.MANAGED_TABLE) + + assert(tmpDir.listFiles.nonEmpty) + sql(s"DROP TABLE $tabName") + // The data are deleted since the table type is not EXTERNAL + assert(tmpDir.listFiles == null) + } + } + } + + test("drop external data source table") { + withTempDir { tmpDir => + val tabName = "tab1" + withTable(tabName) { + assert(tmpDir.listFiles.isEmpty) + + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { + Seq(1 -> "a").toDF("i", "j") + .write + .mode(SaveMode.Overwrite) + .format("parquet") + .option("path", tmpDir.toString) + .saveAsTable(tabName) + } + + val hiveTable = + hiveContext.sessionState.catalog + .getTableMetadata(TableIdentifier(tabName, Some("default"))) + // This data source table is external table + 
assert(hiveTable.tableType == CatalogTableType.EXTERNAL_TABLE) + + assert(tmpDir.listFiles.nonEmpty) + sql(s"DROP TABLE $tabName") + // The data are not deleted since the table type is EXTERNAL + assert(tmpDir.listFiles.nonEmpty) + } + } + } + + test("drop views") { + withTable("tab1") { + val tabName = "tab1" + sqlContext.range(10).write.saveAsTable("tab1") + withView("view1") { + val viewName = "view1" + + assert(tableDirectoryExists(TableIdentifier(tabName))) + assert(!tableDirectoryExists(TableIdentifier(viewName))) + sql(s"CREATE VIEW $viewName AS SELECT * FROM tab1") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + assert(!tableDirectoryExists(TableIdentifier(viewName))) + sql(s"DROP VIEW $viewName") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP VIEW IF EXISTS $viewName") + } + } + } + + test("drop table using drop view") { + withTable("tab1") { + sql("CREATE TABLE tab1(c1 int)") + val message = intercept[AnalysisException] { + sql("DROP VIEW tab1") + }.getMessage + assert(message.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) + } + } + + test("drop view using drop table") { + withTable("tab1") { + sqlContext.range(10).write.saveAsTable("tab1") + withView("view1") { + sql("CREATE VIEW view1 AS SELECT * FROM tab1") + val message = intercept[AnalysisException] { + sql("DROP TABLE view1") + }.getMessage + assert(message.contains("Cannot drop a view with DROP TABLE. Please use DROP VIEW instead")) + } + } + } +} -- cgit v1.2.3 From 5cb5edaf9c5054e42d41f20b2dd92dafcccbf0d6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 9 Apr 2016 17:44:38 -0700 Subject: [SPARK-14419] [SQL] Improve HashedRelation for key fit within Long ## What changes were proposed in this pull request? Currently, we use java HashMap for HashedRelation if the key could fit within a Long. The java HashMap and CompactBuffer are not memory efficient, the memory used by them is also accounted accurately. This PR introduce a LongToUnsafeRowMap (similar to BytesToBytesMap) for better memory efficiency and performance. This PR reopen #12190 to fix bugs. ## How was this patch tested? Existing tests. Author: Davies Liu Closes #12278 from davies/long_map3. --- .../apache/spark/unsafe/map/BytesToBytesMap.java | 14 +- .../execution/aggregate/TungstenAggregate.scala | 3 +- .../sql/execution/joins/BroadcastHashJoin.scala | 18 +- .../spark/sql/execution/joins/HashJoin.scala | 41 +- .../spark/sql/execution/joins/HashedRelation.scala | 648 ++++++++++++++------- .../sql/execution/joins/ShuffledHashJoin.scala | 51 +- .../sql/execution/BenchmarkWholeStageCodegen.scala | 132 ++++- .../apache/spark/sql/execution/ExchangeSuite.scala | 8 +- .../sql/execution/joins/HashedRelationSuite.scala | 48 +- 9 files changed, 602 insertions(+), 361 deletions(-) diff --git a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java index 32958be7a7..6807710f9f 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/BytesToBytesMap.java @@ -716,7 +716,8 @@ public final class BytesToBytesMap extends MemoryConsumer { offset += klen; Platform.copyMemory(vbase, voff, base, offset, vlen); offset += vlen; - Platform.putLong(base, offset, 0); + // put this value at the beginning of the list + Platform.putLong(base, offset, isDefined ? 
longArray.get(pos * 2) : 0); // --- Update bookkeeping data structures ---------------------------------------------------- offset = currentPage.getBaseOffset(); @@ -724,17 +725,12 @@ public final class BytesToBytesMap extends MemoryConsumer { pageCursor += recordLength; final long storedKeyAddress = taskMemoryManager.encodePageNumberAndOffset( currentPage, recordOffset); + longArray.set(pos * 2, storedKeyAddress); + updateAddressesAndSizes(storedKeyAddress); numValues++; - if (isDefined) { - // put this pair at the end of chain - while (nextValue()) { /* do nothing */ } - Platform.putLong(baseObject, valueOffset + valueLength, storedKeyAddress); - nextValue(); // point to new added value - } else { + if (!isDefined) { numKeys++; - longArray.set(pos * 2, storedKeyAddress); longArray.set(pos * 2 + 1, keyHashcode); - updateAddressesAndSizes(storedKeyAddress); isDefined = true; if (numKeys > growthThreshold && longArray.size() < MAX_CAPACITY) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 0a5a72c52a..692fef703f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -454,7 +454,7 @@ case class TungstenAggregate( val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + ctx.addMutableState(hashMapClassName, hashMapTerm, s"") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") @@ -467,6 +467,7 @@ case class TungstenAggregate( s""" ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { + $hashMapTerm = $thisPlan.createHashMap(); ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index e3d554c2de..a8f854136c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.LongType /** * Performs an inner hash join of two child relations. 
When the output RDD of this operator is @@ -50,10 +51,7 @@ case class BroadcastHashJoin( override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution: Seq[Distribution] = { - val mode = HashedRelationBroadcastMode( - canJoinKeyFitWithinLong, - rewriteKeyExpr(buildKeys), - buildPlan.output) + val mode = HashedRelationBroadcastMode(buildKeys) buildSide match { case BuildLeft => BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -68,7 +66,7 @@ case class BroadcastHashJoin( val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() streamedPlan.execute().mapPartitions { streamedIter => val hashed = broadcastRelation.value.asReadOnlyCopy() - TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.getMemorySize) + TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize) join(streamedIter, hashed, numOutputRows) } } @@ -105,7 +103,7 @@ case class BroadcastHashJoin( ctx.addMutableState(clsName, relationTerm, s""" | $relationTerm = (($clsName) $broadcast.value()).asReadOnlyCopy(); - | incPeakExecutionMemory($relationTerm.getMemorySize()); + | incPeakExecutionMemory($relationTerm.estimatedSize()); """.stripMargin) (broadcastRelation, relationTerm) } @@ -118,15 +116,13 @@ case class BroadcastHashJoin( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { ctx.currentVars = input - if (canJoinKeyFitWithinLong) { + if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { // generate the join key as Long - val expr = rewriteKeyExpr(streamedKeys).head - val ev = BindReferences.bindReference(expr, streamedPlan.output).gen(ctx) + val ev = streamedKeys.head.gen(ctx) (ev, ev.isNull) } else { // generate the join key as UnsafeRow - val keyExpr = streamedKeys.map(BindReferences.bindReference(_, streamedPlan.output)) - val ev = GenerateUnsafeProjection.createCode(ctx, keyExpr) + val ev = GenerateUnsafeProjection.createCode(ctx, streamedKeys) (ev, s"${ev.value}.anyNull()") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 8f45d57126..d6feedc272 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,16 +17,12 @@ package org.apache.spark.sql.execution.joins -import java.util.NoSuchElementException - -import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.{RowIterator, SparkPlan} import org.apache.spark.sql.execution.metric.LongSQLMetric -import org.apache.spark.sql.types.{IntegerType, IntegralType, LongType} -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.sql.types.{IntegralType, LongType} trait HashJoin { self: SparkPlan => @@ -59,9 +55,15 @@ trait HashJoin { case BuildRight => (right, left) } - protected lazy val (buildKeys, streamedKeys) = buildSide match { - case BuildLeft => (leftKeys, rightKeys) - case BuildRight => (rightKeys, leftKeys) + protected lazy val (buildKeys, streamedKeys) = { + require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType), + "Join keys from two sides should have same types") + val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output)) + val rkeys = 
rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output)) + buildSide match { + case BuildLeft => (lkeys, rkeys) + case BuildRight => (rkeys, lkeys) + } } /** @@ -69,7 +71,7 @@ trait HashJoin { * * If not, returns the original expressions. */ - def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { + private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { var keyExpr: Expression = null var width = 0 keys.foreach { e => @@ -84,17 +86,8 @@ trait HashJoin { width = dt.defaultSize } else { val bits = dt.defaultSize * 8 - // hashCode of Long is (l >> 32) ^ l.toInt, it means the hash code of an long with same - // value in high 32 bit and low 32 bit will be 0. To avoid the worst case that keys - // with two same ints have hash code 0, we rotate the bits of second one. - val rotated = if (e.dataType == IntegerType) { - // (e >>> 15) | (e << 17) - BitwiseOr(ShiftRightUnsigned(e, Literal(15)), ShiftLeft(e, Literal(17))) - } else { - e - } keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(rotated, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) width -= bits } // TODO: support BooleanType, DateType and TimestampType @@ -105,17 +98,11 @@ trait HashJoin { keyExpr :: Nil } - protected lazy val canJoinKeyFitWithinLong: Boolean = { - val sameTypes = buildKeys.map(_.dataType) == streamedKeys.map(_.dataType) - val key = rewriteKeyExpr(buildKeys) - sameTypes && key.length == 1 && key.head.dataType.isInstanceOf[LongType] - } - protected def buildSideKeyGenerator(): Projection = - UnsafeProjection.create(rewriteKeyExpr(buildKeys), buildPlan.output) + UnsafeProjection.create(buildKeys) protected def streamSideKeyGenerator(): UnsafeProjection = - UnsafeProjection.create(rewriteKeyExpr(streamedKeys), streamedPlan.output) + UnsafeProjection.create(streamedKeys) @transient private[this] lazy val boundCondition = if (condition.isDefined) { newPredicate(condition.get, streamedPlan.output ++ buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 5ccb435686..68b5486faa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -18,24 +18,22 @@ package org.apache.spark.sql.execution.joins import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} -import java.util.{HashMap => JavaHashMap} -import org.apache.spark.{SparkConf, SparkEnv, SparkException, TaskContext} -import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.{SparkConf, SparkEnv, SparkException} +import org.apache.spark.memory.{MemoryConsumer, MemoryMode, StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.types.LongType import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.{KnownSizeEstimation, Utils} -import org.apache.spark.util.collection.CompactBuffer /** * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete * object. 
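A worked example of the key rewriting now done once in HashJoin.rewriteKeyExpr and shared by the build and streamed sides (the key values are made up for illustration):

// Two IntegerType join keys (a, b) are packed into one LongType key as
//   (cast(a, long) << 32) | (cast(b, long) & 0xffffffffL)
// e.g. a = 7, b = -1:
val packed: Long = (7L << 32) | (-1L & 0xffffffffL)   // 0x00000007FFFFFFFFL
// The old int-rotation trick is dropped because LongToUnsafeRowMap hashes the
// packed long itself (see firstSlot below) rather than relying on Long.hashCode.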
*/ -private[execution] sealed trait HashedRelation { +private[execution] sealed trait HashedRelation extends KnownSizeEstimation { /** * Returns matched rows. * @@ -74,51 +72,36 @@ private[execution] sealed trait HashedRelation { */ def asReadOnlyCopy(): HashedRelation - /** - * Returns the size of used memory. - */ - def getMemorySize: Long = 1L // to make the test happy - /** * Release any used resources. */ - def close(): Unit = {} - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def writeBytes(out: ObjectOutput, serialized: Array[Byte]): Unit = { - out.writeInt(serialized.length) // Write the length of serialized bytes first - out.write(serialized) - } - - // This is a helper method to implement Externalizable, and is used by - // GeneralHashedRelation and UniqueKeyHashedRelation - protected def readBytes(in: ObjectInput): Array[Byte] = { - val serializedSize = in.readInt() // Read the length of serialized bytes first - val bytes = new Array[Byte](serializedSize) - in.readFully(bytes) - bytes - } + def close(): Unit } private[execution] object HashedRelation { /** * Create a HashedRelation from an Iterator of InternalRow. - * - * Note: The caller should make sure that these InternalRow are different objects. */ def apply( - canJoinKeyFitWithinLong: Boolean, input: Iterator[InternalRow], - keyGenerator: Projection, - sizeEstimate: Int = 64): HashedRelation = { + key: Seq[Expression], + sizeEstimate: Int = 64, + taskMemoryManager: TaskMemoryManager = null): HashedRelation = { + val mm = Option(taskMemoryManager).getOrElse { + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + } - if (canJoinKeyFitWithinLong) { - LongHashedRelation(input, keyGenerator, sizeEstimate) + if (key.length == 1 && key.head.dataType == LongType) { + LongHashedRelation(input, key, sizeEstimate, mm) } else { - UnsafeHashedRelation( - input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate) + UnsafeHashedRelation(input, key, sizeEstimate, mm) } } } @@ -133,7 +116,7 @@ private[execution] object HashedRelation { private[joins] class UnsafeHashedRelation( private var numFields: Int, private var binaryMap: BytesToBytesMap) - extends HashedRelation with KnownSizeEstimation with Externalizable { + extends HashedRelation with Externalizable { private[joins] def this() = this(0, null) // Needed for serialization @@ -142,10 +125,6 @@ private[joins] class UnsafeHashedRelation( override def asReadOnlyCopy(): UnsafeHashedRelation = new UnsafeHashedRelation(numFields, binaryMap) - override def getMemorySize: Long = { - binaryMap.getTotalMemoryConsumption - } - override def estimatedSize: Long = { binaryMap.getTotalMemoryConsumption } @@ -276,20 +255,10 @@ private[joins] object UnsafeHashedRelation { def apply( input: Iterator[InternalRow], - keyGenerator: UnsafeProjection, - sizeEstimate: Int): HashedRelation = { + key: Seq[Expression], + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): HashedRelation = { - val taskMemoryManager = if (TaskContext.get() != null) { - TaskContext.get().taskMemoryManager() - } else { - new TaskMemoryManager( - new StaticMemoryManager( - new SparkConf().set("spark.memory.offHeap.enabled", "false"), - Long.MaxValue, - Long.MaxValue, - 1), - 0) - } val pageSizeBytes = Option(SparkEnv.get).map(_.memoryManager.pageSizeBytes) .getOrElse(new 
SparkConf().getSizeAsBytes("spark.buffer.pageSize", "16m")) @@ -300,6 +269,7 @@ private[joins] object UnsafeHashedRelation { pageSizeBytes) // Create a mapping of buildKeys -> rows + val keyGenerator = UnsafeProjection.create(key) var numFields = 0 while (input.hasNext) { val row = input.next().asInstanceOf[UnsafeRow] @@ -322,143 +292,430 @@ private[joins] object UnsafeHashedRelation { } /** - * An interface for a hashed relation that the key is a Long. + * An append-only hash map mapping from key of Long to UnsafeRow. + * + * The underlying bytes of all values (UnsafeRows) are packed together as a single byte array + * (`page`) in this format: + * + * [bytes of row1][address1][bytes of row2][address1] ... + * + * address1 (8 bytes) is the offset and size of next value for the same key as row1, any key + * could have multiple values. the address at the end of last value for every key is 0. + * + * The keys and addresses of their values could be stored in two modes: + * + * 1) sparse mode: the keys and addresses are stored in `array` as: + * + * [key1][address1][key2][address2]...[] + * + * address1 (Long) is the offset (in `page`) and size of the value for key1. The position of key1 + * is determined by `key1 % cap`. Quadratic probing with triangular numbers is used to address + * hash collision. + * + * 2) dense mode: all the addresses are packed into a single array of long, as: + * + * [address1] [address2] ... + * + * address1 (Long) is the offset (in `page`) and size of the value for key1, the position is + * determined by `key1 - minKey`. + * + * The map is created as sparse mode, then key-value could be appended into it. Once finish + * appending, caller could all optimize() to try to turn the map into dense mode, which is faster + * to probe. + * + * see http://java-performance.info/implementing-world-fastest-java-int-to-int-hash-map/ */ -private[joins] trait LongHashedRelation extends HashedRelation { - override def get(key: InternalRow): Iterator[InternalRow] = { - get(key.getLong(0)) +private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, capacity: Int) + extends MemoryConsumer(mm) with Externalizable { + + // Whether the keys are stored in dense mode or not. + private var isDense = false + + // The minimum key + private var minKey = Long.MaxValue + + // The maxinum key + private var maxKey = Long.MinValue + + // The array to store the key and offset of UnsafeRow in the page. + // + // Sparse mode: [key1] [offset1 | size1] [key2] [offset | size2] ... + // Dense mode: [offset1 | size1] [offset2 | size2] + private var array: Array[Long] = null + private var mask: Int = 0 + + // The page to store all bytes of UnsafeRow and the pointer to next rows. + // [row1][pointer1] [row2][pointer2] + private var page: Array[Byte] = null + + // Current write cursor in the page. + private var cursor = Platform.BYTE_ARRAY_OFFSET + + // The total number of values of all keys. + private var numValues = 0 + + // The number of unique keys. 
+ private var numKeys = 0 + + // needed by serializer + def this() = { + this( + new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0), + 0) } - override def getValue(key: InternalRow): InternalRow = { - getValue(key.getLong(0)) + + private def acquireMemory(size: Long): Unit = { + // do not support spilling + val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) + if (got < size) { + mm.releaseExecutionMemory(got, MemoryMode.ON_HEAP, this) + throw new SparkException(s"Can't acquire $size bytes memory to build hash relation") + } } -} -private[joins] final class GeneralLongHashedRelation( - private var hashTable: JavaHashMap[Long, CompactBuffer[UnsafeRow]]) - extends LongHashedRelation with Externalizable { + private def freeMemory(size: Long): Unit = { + mm.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this) + } - // Needed for serialization (it is public to make Java serialization work) - def this() = this(null) + private def init(): Unit = { + if (mm != null) { + var n = 1 + while (n < capacity) n *= 2 + acquireMemory(n * 2 * 8 + (1 << 20)) + array = new Array[Long](n * 2) + mask = n * 2 - 2 + page = new Array[Byte](1 << 20) // 1M bytes + } + } + + init() + + def spill(size: Long, trigger: MemoryConsumer): Long = { + 0L + } + + /** + * Returns whether all the keys are unique. + */ + def keyIsUnique: Boolean = numKeys == numValues + + /** + * Returns total memory consumption. + */ + def getTotalMemoryConsumption: Long = { + array.length * 8 + page.length + } - override def keyIsUnique: Boolean = false + /** + * Returns the first slot of array that store the keys (sparse mode). + */ + private def firstSlot(key: Long): Int = { + val h = key * 0x9E3779B9L + (h ^ (h >> 32)).toInt & mask + } - override def asReadOnlyCopy(): GeneralLongHashedRelation = - new GeneralLongHashedRelation(hashTable) + /** + * Returns the next probe in the array. + */ + private def nextSlot(pos: Int): Int = { + (pos + 2) & mask + } - override def get(key: Long): Iterator[InternalRow] = { - val rows = hashTable.get(key) - if (rows != null) { - rows.toIterator + private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { + val offset = address >>> 32 + val size = address & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + resultRow + } + + /** + * Returns the single UnsafeRow for given key, or null if not found. + */ + def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >= 0 && key <= maxKey && array(idx) > 0) { + return getRow(array(idx), resultRow) + } } else { - null + var pos = firstSlot(key) + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return getRow(array(pos + 1), resultRow) + } + pos = nextSlot(pos) + } + } + null + } + + /** + * Returns an interator of UnsafeRow for multiple linked values. + */ + private def valueIter(address: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + new Iterator[UnsafeRow] { + var addr = address + override def hasNext: Boolean = addr != 0 + override def next(): UnsafeRow = { + val offset = addr >>> 32 + val size = addr & 0xffffffffL + resultRow.pointTo(page, offset, size.toInt) + addr = Platform.getLong(page, offset + size) + resultRow + } + } + } + + /** + * Returns an iterator for all the values for the given key, or null if no value found. 
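An illustrative decoding of the packed addresses stored in `array` and in the per-row pointers, matching getRow and valueIter above:

// address layout: page offset in the upper 32 bits, row size in the lower 32 bits
def offsetInPage(address: Long): Long = address >>> 32
def rowSize(address: Long): Long = address & 0xffffffffL
// Every appended row is followed by an 8-byte pointer to the next value with the
// same key (0 terminates the chain); valueIter simply walks that chain.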
+ */ + def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { + if (isDense) { + val idx = (key - minKey).toInt + if (idx >=0 && key <= maxKey && array(idx) > 0) { + return valueIter(array(idx), resultRow) + } + } else { + var pos = firstSlot(key) + while (array(pos + 1) != 0) { + if (array(pos) == key) { + return valueIter(array(pos + 1), resultRow) + } + pos = nextSlot(pos) + } + } + null + } + + /** + * Appends the key and row into this map. + */ + def append(key: Long, row: UnsafeRow): Unit = { + if (key < minKey) { + minKey = key + } + if (key > maxKey) { + maxKey = key + } + + // There is 8 bytes for the pointer to next value + if (cursor + 8 + row.getSizeInBytes > page.length + Platform.BYTE_ARRAY_OFFSET) { + val used = page.length + if (used * 2L > (1L << 31)) { + sys.error("Can't allocate a page that is larger than 2G") + } + acquireMemory(used * 2) + val newPage = new Array[Byte](used * 2) + System.arraycopy(page, 0, newPage, 0, cursor - Platform.BYTE_ARRAY_OFFSET) + page = newPage + freeMemory(used) + } + + // copy the bytes of UnsafeRow + val offset = cursor + Platform.copyMemory(row.getBaseObject, row.getBaseOffset, page, cursor, row.getSizeInBytes) + cursor += row.getSizeInBytes + Platform.putLong(page, cursor, 0) + cursor += 8 + numValues += 1 + updateIndex(key, (offset.toLong << 32) | row.getSizeInBytes) + } + + /** + * Update the address in array for given key. + */ + private def updateIndex(key: Long, address: Long): Unit = { + var pos = firstSlot(key) + while (array(pos) != key && array(pos + 1) != 0) { + pos = nextSlot(pos) + } + if (array(pos + 1) == 0) { + // this is the first value for this key, put the address in array. + array(pos) = key + array(pos + 1) = address + numKeys += 1 + if (numKeys * 4 > array.length) { + // reach half of the capacity + growArray() + } + } else { + // there are some values for this key, put the address in the front of them. + val pointer = (address >>> 32) + (address & 0xffffffffL) + Platform.putLong(page, pointer, array(pos + 1)) + array(pos + 1) = address + } + } + + private def growArray(): Unit = { + var old_array = array + val n = array.length + numKeys = 0 + acquireMemory(n * 2 * 8) + array = new Array[Long](n * 2) + mask = n * 2 - 2 + var i = 0 + while (i < old_array.length) { + if (old_array(i + 1) > 0) { + updateIndex(old_array(i), old_array(i + 1)) + } + i += 2 + } + old_array = null // release the reference to old array + freeMemory(n * 8) + } + + /** + * Try to turn the map into dense mode, which is faster to probe. + */ + def optimize(): Unit = { + val range = maxKey - minKey + // Convert to dense mode if it does not require more memory or could fit within L1 cache + if (range < array.length || range < 1024) { + try { + acquireMemory((range + 1) * 8) + } catch { + case e: SparkException => + // there is no enough memory to convert + return + } + val denseArray = new Array[Long]((range + 1).toInt) + var i = 0 + while (i < array.length) { + if (array(i + 1) > 0) { + val idx = (array(i) - minKey).toInt + denseArray(idx) = array(i + 1) + } + i += 2 + } + val old_length = array.length + array = denseArray + isDense = true + freeMemory(old_length * 8) + } + } + + /** + * Free all the memory acquired by this map. 
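A small worked example of the dense-mode decision in optimize() above (the key range and array size are assumed for illustration):

// Keys 100..160 appended into a map whose sparse array currently has 128 slots:
val range = 160L - 100L                       // maxKey - minKey = 60
val goDense = range < 128 || range < 1024     // mirrors the check in optimize()
// Since goDense holds, the addresses are repacked into a dense Array[Long] of
// range + 1 = 61 entries indexed by (key - minKey), so a probe becomes a single
// array lookup with no hashing or collision handling.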
+ */ + def free(): Unit = { + if (page != null) { + freeMemory(page.length) + page = null + } + if (array != null) { + freeMemory(array.length * 8) + array = null } } override def writeExternal(out: ObjectOutput): Unit = { - writeBytes(out, SparkSqlSerializer.serialize(hashTable)) + out.writeBoolean(isDense) + out.writeLong(minKey) + out.writeLong(maxKey) + out.writeInt(numKeys) + out.writeInt(numValues) + + out.writeInt(array.length) + val buffer = new Array[Byte](4 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = array.length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + Platform.copyMemory(array, offset, buffer, Platform.BYTE_ARRAY_OFFSET, size) + out.write(buffer, 0, size) + offset += size + } + + val used = cursor - Platform.BYTE_ARRAY_OFFSET + out.writeInt(used) + out.write(page, 0, used) } override def readExternal(in: ObjectInput): Unit = { - hashTable = SparkSqlSerializer.deserialize(readBytes(in)) + isDense = in.readBoolean() + minKey = in.readLong() + maxKey = in.readLong() + numKeys = in.readInt() + numValues = in.readInt() + + val length = in.readInt() + array = new Array[Long](length) + mask = length - 2 + val buffer = new Array[Byte](4 << 10) + var offset = Platform.LONG_ARRAY_OFFSET + val end = length * 8 + Platform.LONG_ARRAY_OFFSET + while (offset < end) { + val size = Math.min(buffer.length, end - offset) + in.readFully(buffer, 0, size) + Platform.copyMemory(buffer, Platform.BYTE_ARRAY_OFFSET, array, offset, size) + offset += size + } + + val numBytes = in.readInt() + page = new Array[Byte](numBytes) + in.readFully(page) } } -/** - * A relation that pack all the rows into a byte array, together with offsets and sizes. - * - * All the bytes of UnsafeRow are packed together as `bytes`: - * - * [ Row0 ][ Row1 ][] ... [ RowN ] - * - * With keys: - * - * start start+1 ... start+N - * - * `offsets` are offsets of UnsafeRows in the `bytes` - * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. 
- * - * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: - * - * start = 3 - * offsets = [0, 0, 24] - * sizes = [24, 0, 32] - * bytes = [0 - 24][][24 - 56] - */ -private[joins] final class LongArrayRelation( - private var numFields: Int, - private var start: Long, - private var offsets: Array[Int], - private var sizes: Array[Int], - private var bytes: Array[Byte] - ) extends LongHashedRelation with Externalizable { +private[joins] class LongHashedRelation( + private var nFields: Int, + private var map: LongToUnsafeRowMap) extends HashedRelation with Externalizable { + + private var resultRow: UnsafeRow = new UnsafeRow(nFields) // Needed for serialization (it is public to make Java serialization work) - def this() = this(0, 0L, null, null, null) + def this() = this(0, null) - override def keyIsUnique: Boolean = true + override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) - override def asReadOnlyCopy(): LongArrayRelation = { - new LongArrayRelation(numFields, start, offsets, sizes, bytes) + override def estimatedSize: Long = { + map.getTotalMemoryConsumption } - override def getMemorySize: Long = { - offsets.length * 4 + sizes.length * 4 + bytes.length + override def get(key: InternalRow): Iterator[InternalRow] = { + if (key.isNullAt(0)) { + null + } else { + get(key.getLong(0)) + } } - override def get(key: Long): Iterator[InternalRow] = { - val row = getValue(key) - if (row != null) { - Seq(row).toIterator - } else { + override def getValue(key: InternalRow): InternalRow = { + if (key.isNullAt(0)) { null + } else { + getValue(key.getLong(0)) } } - var resultRow = new UnsafeRow(numFields) + override def get(key: Long): Iterator[InternalRow] = + map.get(key, resultRow) + override def getValue(key: Long): InternalRow = { - val idx = (key - start).toInt - if (idx >= 0 && idx < sizes.length && sizes(idx) > 0) { - resultRow.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(idx), sizes(idx)) - resultRow - } else { - null - } + map.getValue(key, resultRow) + } + + override def keyIsUnique: Boolean = map.keyIsUnique + + override def close(): Unit = { + map.free() } override def writeExternal(out: ObjectOutput): Unit = { - out.writeInt(numFields) - out.writeLong(start) - out.writeInt(sizes.length) - var i = 0 - while (i < sizes.length) { - out.writeInt(sizes(i)) - i += 1 - } - out.writeInt(bytes.length) - out.write(bytes) + out.writeInt(nFields) + out.writeObject(map) } override def readExternal(in: ObjectInput): Unit = { - numFields = in.readInt() - resultRow = new UnsafeRow(numFields) - start = in.readLong() - val length = in.readInt() - // read sizes of rows - sizes = new Array[Int](length) - offsets = new Array[Int](length) - var i = 0 - var offset = 0 - while (i < length) { - offsets(i) = offset - sizes(i) = in.readInt() - offset += sizes(i) - i += 1 - } - // read all the bytes - val total = in.readInt() - assert(total == offset) - bytes = new Array[Byte](total) - in.readFully(bytes) + nFields = in.readInt() + resultRow = new UnsafeRow(nFields) + map = in.readObject().asInstanceOf[LongToUnsafeRowMap] } } @@ -466,96 +723,45 @@ private[joins] final class LongArrayRelation( * Create hashed relation with key that is long. 
*/ private[joins] object LongHashedRelation { - - val DENSE_FACTOR = 0.2 - def apply( - input: Iterator[InternalRow], - keyGenerator: Projection, - sizeEstimate: Int): HashedRelation = { + input: Iterator[InternalRow], + key: Seq[Expression], + sizeEstimate: Int, + taskMemoryManager: TaskMemoryManager): LongHashedRelation = { - // TODO: use LongToBytesMap for better memory efficiency - val hashTable = new JavaHashMap[Long, CompactBuffer[UnsafeRow]](sizeEstimate) + val map: LongToUnsafeRowMap = new LongToUnsafeRowMap(taskMemoryManager, sizeEstimate) + val keyGenerator = UnsafeProjection.create(key) // Create a mapping of key -> rows var numFields = 0 - var keyIsUnique = true - var minKey = Long.MaxValue - var maxKey = Long.MinValue while (input.hasNext) { val unsafeRow = input.next().asInstanceOf[UnsafeRow] numFields = unsafeRow.numFields() val rowKey = keyGenerator(unsafeRow) - if (!rowKey.anyNull) { + if (!rowKey.isNullAt(0)) { val key = rowKey.getLong(0) - minKey = math.min(minKey, key) - maxKey = math.max(maxKey, key) - val existingMatchList = hashTable.get(key) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[UnsafeRow]() - hashTable.put(key, newMatchList) - newMatchList - } else { - keyIsUnique = false - existingMatchList - } - matchList += unsafeRow - } - } - - if (keyIsUnique && hashTable.size() > (maxKey - minKey) * DENSE_FACTOR) { - // The keys are dense enough, so use LongArrayRelation - val length = (maxKey - minKey).toInt + 1 - val sizes = new Array[Int](length) - val offsets = new Array[Int](length) - var offset = 0 - var i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - offsets(i) = offset - sizes(i) = rows(0).getSizeInBytes - offset += sizes(i) - } - i += 1 - } - val bytes = new Array[Byte](offset) - i = 0 - while (i < length) { - val rows = hashTable.get(i + minKey) - if (rows != null) { - rows(0).writeToMemory(bytes, Platform.BYTE_ARRAY_OFFSET + offsets(i)) - } - i += 1 + map.append(key, unsafeRow) } - new LongArrayRelation(numFields, minKey, offsets, sizes, bytes) - } else { - new GeneralLongHashedRelation(hashTable) } + map.optimize() + new LongHashedRelation(numFields, map) } } /** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. 
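A rough usage sketch of the new long-keyed map, following the append/optimize pattern of LongHashedRelation.apply above; `mm` (a TaskMemoryManager), `rows`, `keyOf` and `numFields` are assumed for illustration:

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap

val map = new LongToUnsafeRowMap(mm, 64)             // 64 is the initial size estimate
rows.foreach { row => map.append(keyOf(row), row) }  // row: UnsafeRow, keyOf: Long key
map.optimize()                                       // try to switch to dense mode
val matches = map.get(42L, new UnsafeRow(numFields)) // Iterator[UnsafeRow], or null
map.free()                                           // release the acquired memory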
*/ -private[execution] case class HashedRelationBroadcastMode( - canJoinKeyFitWithinLong: Boolean, - keys: Seq[Expression], - attributes: Seq[Attribute]) extends BroadcastMode { +private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) + extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - val generator = UnsafeProjection.create(keys, attributes) - HashedRelation(canJoinKeyFitWithinLong, rows.iterator, generator, rows.length) + HashedRelation(rows.iterator, canonicalizedKey, rows.length) } - private lazy val canonicalizedKeys: Seq[Expression] = { - keys.map { e => - BindReferences.bindReference(e.canonicalized, attributes) - } + private lazy val canonicalizedKey: Seq[Expression] = { + key.map { e => e.canonicalized } } override def compatibleWith(other: BroadcastMode): Boolean = other match { - case m: HashedRelationBroadcastMode => - canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong && - canonicalizedKeys == m.canonicalizedKeys + case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey case _ => false } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index bf86096379..0c3e3c3fc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.memory.MemoryMode +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{BindReferences, Expression, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -57,54 +56,20 @@ case class ShuffledHashJoin( override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - private def buildHashedRelation(iter: Iterator[UnsafeRow]): HashedRelation = { + private def buildHashedRelation(iter: Iterator[InternalRow]): HashedRelation = { val context = TaskContext.get() - if (!canJoinKeyFitWithinLong) { - // build BytesToBytesMap - val relation = HashedRelation(canJoinKeyFitWithinLong, iter, buildSideKeyGenerator) - // This relation is usually used until the end of task. - context.addTaskCompletionListener((t: TaskContext) => - relation.close() - ) - return relation - } - - // try to acquire some memory for the hash table, it could trigger other operator to free some - // memory. The memory acquired here will mostly be used until the end of task. - val memoryManager = context.taskMemoryManager() - var acquired = 0L - var used = 0L + val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) + // This relation is usually used until the end of task. context.addTaskCompletionListener((t: TaskContext) => - memoryManager.releaseExecutionMemory(acquired, MemoryMode.ON_HEAP, null) + relation.close() ) - - val copiedIter = iter.map { row => - // It's hard to guess what's exactly memory will be used, we have a rough guess here. 
- // TODO: use LongToBytesMap instead of HashMap for memory efficiency - // Each pair in HashMap will have UnsafeRow, CompactBuffer, maybe 10+ pointers - val needed = 150 + row.getSizeInBytes - if (needed > acquired - used) { - val got = memoryManager.acquireExecutionMemory( - Math.max(memoryManager.pageSizeBytes(), needed), MemoryMode.ON_HEAP, null) - acquired += got - if (got < needed) { - throw new SparkException("Can't acquire enough memory to build hash map in shuffled" + - "hash join, please use sort merge join by setting " + - "spark.sql.join.preferSortMergeJoin=true") - } - } - used += needed - // HashedRelation requires that the UnsafeRow should be separate objects. - row.copy() - } - - HashedRelation(canJoinKeyFitWithinLong, copiedIter, buildSideKeyGenerator) + relation } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) => - val hashed = buildHashedRelation(buildIter.asInstanceOf[Iterator[UnsafeRow]]) + val hashed = buildHashedRelation(buildIter) join(streamIter, hashed, numOutputRows) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 5dbf619876..352fd07d0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -21,6 +21,7 @@ import java.util.HashMap import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} +import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.vectorized.AggregateHashMap @@ -179,8 +180,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long codegen=false 5351 / 5531 3.9 255.1 1.0X - Join w long codegen=true 275 / 352 76.2 13.1 19.4X + Join w long codegen=false 3002 / 3262 7.0 143.2 1.0X + Join w long codegen=true 321 / 371 65.3 15.3 9.3X */ runBenchmark("Join w long duplicated", N) { @@ -193,8 +194,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w long duplicated: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w long duplicated codegen=false 4752 / 4906 4.4 226.6 1.0X - Join w long duplicated codegen=true 722 / 760 29.0 34.4 6.6X + Join w long duplicated codegen=false 3446 / 3478 6.1 164.3 1.0X + Join w long duplicated codegen=true 322 / 351 65.2 15.3 10.7X */ val dim2 = broadcast(sqlContext.range(M) @@ -211,8 +212,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz Join w 2 ints: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Join w 2 ints codegen=false 9011 / 9121 2.3 429.7 1.0X - Join w 2 ints codegen=true 2565 / 2816 8.2 122.3 3.5X + Join w 2 ints codegen=false 4426 / 4501 4.7 211.1 1.0X + Join w 2 
ints codegen=true 791 / 818 26.5 37.7 5.6X */ val dim3 = broadcast(sqlContext.range(M) @@ -259,8 +260,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz outer join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - outer join w long codegen=false 5667 / 5780 3.7 270.2 1.0X - outer join w long codegen=true 216 / 226 97.2 10.3 26.3X + outer join w long codegen=false 3055 / 3189 6.9 145.7 1.0X + outer join w long codegen=true 261 / 276 80.5 12.4 11.7X */ runBenchmark("semi join w long", N) { @@ -272,8 +273,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz semi join w long: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - semi join w long codegen=false 4690 / 4953 4.5 223.7 1.0X - semi join w long codegen=true 211 / 229 99.2 10.1 22.2X + semi join w long codegen=false 1912 / 1990 11.0 91.2 1.0X + semi join w long codegen=true 237 / 244 88.3 11.3 8.1X */ } @@ -326,8 +327,8 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz shuffle hash join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - shuffle hash join codegen=false 1538 / 1742 2.7 366.7 1.0X - shuffle hash join codegen=true 892 / 1329 4.7 212.6 1.7X + shuffle hash join codegen=false 1101 / 1391 3.8 262.6 1.0X + shuffle hash join codegen=true 528 / 578 7.9 125.8 2.1X */ } @@ -349,11 +350,11 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } ignore("hash and BytesToBytesMap") { - val N = 10 << 20 + val N = 20 << 20 val benchmark = new Benchmark("BytesToBytesMap", N) - benchmark.addCase("hash") { iter => + benchmark.addCase("UnsafeRowhash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) @@ -368,15 +369,34 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + benchmark.addCase("murmur3 hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 + var s = 0 + while (i < N) { + var h = Murmur3_x86_32.hashLong(i, 42) + key.setInt(0, h) + s += h + i += 1 + } + } + benchmark.addCase("fast hash") { iter => var i = 0 val keyBytes = new Array[Byte](16) val key = new UnsafeRow(1) key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var p = 524283 var s = 0 while (i < N) { - key.setInt(0, i % 1000) - val h = Murmur3_x86_32.hashLong(i % 1000, 42) + var h = i % p + if (h < 0) { + h += p + } + key.setInt(0, h) s += h i += 1 } @@ -475,6 +495,42 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { } } + Seq(false, true).foreach { optimized => + benchmark.addCase(s"LongToUnsafeRowMap (opt=$optimized)") { iter => + var i = 0 + val valueBytes = new Array[Byte](16) + val value = new UnsafeRow(1) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + value.setInt(0, 555) + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new LongToUnsafeRowMap(taskMemoryManager, 64) + while (i < 65536) { + value.setInt(0, i) + val key = i % 100000 + map.append(key, value) + i += 1 + } + if (optimized) { + map.optimize() + 
} + var s = 0 + i = 0 + while (i < N) { + val key = i % 100000 + if (map.getValue(key, value) != null) { + s += 1 + } + i += 1 + } + } + } + Seq("off", "on").foreach { heap => benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => val taskMemoryManager = new TaskMemoryManager( @@ -493,18 +549,27 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { val value = new UnsafeRow(1) value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) var i = 0 - while (i < N) { + val numKeys = 65536 + while (i < numKeys) { key.setInt(0, i % 65536) val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, Murmur3_x86_32.hashLong(i % 65536, 42)) - if (loc.isDefined) { - value.pointTo(loc.getValueBase, loc.getValueOffset, loc.getValueLength) - value.setInt(0, value.getInt(0) + 1) - i += 1 - } else { + if (!loc.isDefined) { loc.append(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) } + i += 1 + } + i = 0 + var s = 0 + while (i < N) { + key.setInt(0, i % 100000) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + Murmur3_x86_32.hashLong(i % 100000, 42)) + if (loc.isDefined) { + s += 1 + } + i += 1 } } } @@ -535,16 +600,19 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz BytesToBytesMap: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - hash 112 / 116 93.2 10.7 1.0X - fast hash 65 / 69 160.9 6.2 1.7X - arrayEqual 66 / 69 159.1 6.3 1.7X - Java HashMap (Long) 137 / 182 76.3 13.1 0.8X - Java HashMap (two ints) 182 / 230 57.8 17.3 0.6X - Java HashMap (UnsafeRow) 511 / 565 20.5 48.8 0.2X - BytesToBytesMap (off Heap) 481 / 515 21.8 45.9 0.2X - BytesToBytesMap (on Heap) 529 / 600 19.8 50.5 0.2X - Aggregate HashMap 56 / 62 187.9 5.3 2.0X - */ + UnsafeRow hash 267 / 284 78.4 12.8 1.0X + murmur3 hash 102 / 129 205.5 4.9 2.6X + fast hash 79 / 96 263.8 3.8 3.4X + arrayEqual 164 / 172 128.2 7.8 1.6X + Java HashMap (Long) 321 / 399 65.4 15.3 0.8X + Java HashMap (two ints) 328 / 363 63.9 15.7 0.8X + Java HashMap (UnsafeRow) 1140 / 1200 18.4 54.3 0.2X + LongToUnsafeRowMap (opt=false) 378 / 400 55.5 18.0 0.7X + LongToUnsafeRowMap (opt=true) 144 / 152 145.2 6.9 1.9X + BytesToBytesMap (off Heap) 1300 / 1616 16.1 62.0 0.2X + BytesToBytesMap (on Heap) 1165 / 1202 18.0 55.5 0.2X + Aggregate HashMap 121 / 131 173.3 5.8 2.2X + */ benchmark.run() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 9680f3a008..17f2343cf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -38,8 +38,8 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { test("compatible BroadcastMode") { val mode1 = IdentityBroadcastMode - val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq()) - val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq()) + val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil) + val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil) assert(mode1.compatibleWith(mode1)) assert(!mode1.compatibleWith(mode2)) @@ -56,10 +56,10 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val exchange1 = 
BroadcastExchange(IdentityBroadcastMode, plan) - val hashMode = HashedRelationBroadcastMode(true, output, plan.output) + val hashMode = HashedRelationBroadcastMode(output) val exchange2 = BroadcastExchange(hashMode, plan) val hashMode2 = - HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output) + HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) val exchange3 = BroadcastExchange(hashMode2, plan) val exchange4 = ReusedExchange(output, exchange3) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index ed87a99439..371a9ed617 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -30,15 +30,23 @@ import org.apache.spark.util.collection.CompactBuffer class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { + val mm = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + test("UnsafeHashedRelation") { val schema = StructType(StructField("a", IntegerType, true) :: Nil) val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2)) val toUnsafe = UnsafeProjection.create(schema) val unsafeData = data.map(toUnsafe(_).copy()) + val buildKey = Seq(BoundReference(0, IntegerType, false)) - val keyGenerator = UnsafeProjection.create(buildKey) - val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1) + val hashed = UnsafeHashedRelation(unsafeData.iterator, buildKey, 1, mm) assert(hashed.isInstanceOf[UnsafeHashedRelation]) assert(hashed.get(unsafeData(0)).toArray === Array(unsafeData(0))) @@ -100,31 +108,45 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { assert(java.util.Arrays.equals(os2.toByteArray, os.toByteArray)) } - test("LongArrayRelation") { + test("LongToUnsafeRowMap") { val unsafeProj = UnsafeProjection.create( Seq(BoundReference(0, IntegerType, false), BoundReference(1, IntegerType, true))) val rows = (0 until 100).map(i => unsafeProj(InternalRow(i, i + 1)).copy()) - val keyProj = UnsafeProjection.create(Seq(BoundReference(0, IntegerType, false))) - val longRelation = LongHashedRelation(rows.iterator, keyProj, 100) - assert(longRelation.isInstanceOf[LongArrayRelation]) - val longArrayRelation = longRelation.asInstanceOf[LongArrayRelation] + val key = Seq(BoundReference(0, IntegerType, false)) + val longRelation = LongHashedRelation(rows.iterator, key, 10, mm) + assert(longRelation.keyIsUnique) (0 until 100).foreach { i => - val row = longArrayRelation.getValue(i) + val row = longRelation.getValue(i) assert(row.getInt(0) === i) assert(row.getInt(1) === i + 1) } + val longRelation2 = LongHashedRelation(rows.iterator ++ rows.iterator, key, 100, mm) + assert(!longRelation2.keyIsUnique) + (0 until 100).foreach { i => + val rows = longRelation2.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) + } + val os = new ByteArrayOutputStream() val out = new ObjectOutputStream(os) - longArrayRelation.writeExternal(out) + longRelation2.writeExternal(out) out.flush() val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray)) - val relation = new LongArrayRelation() + 
val relation = new LongHashedRelation() relation.readExternal(in) + assert(!relation.keyIsUnique) (0 until 100).foreach { i => - val row = longArrayRelation.getValue(i) - assert(row.getInt(0) === i) - assert(row.getInt(1) === i + 1) + val rows = relation.get(i).toArray + assert(rows.length === 2) + assert(rows(0).getInt(0) === i) + assert(rows(0).getInt(1) === i + 1) + assert(rows(1).getInt(0) === i) + assert(rows(1).getInt(1) === i + 1) } } } -- cgit v1.2.3 From 5989c85b535f7f623392d6456d8b37052487f24b Mon Sep 17 00:00:00 2001 From: Nong Li Date: Sat, 9 Apr 2016 17:45:10 -0700 Subject: [SPARK-14217] [SQL] Fix bug if parquet data has columns that use dictionary encoding for some of the data ## What changes were proposed in this pull request? This PR is based on #12017 Currently, this causes batches where some values are dictionary encoded and some which are not. The non-dictionary encoded values cause us to remove the dictionary from the batch causing the first values to return garbage. This patch fixes the issue by first decoding the dictionary for the values that are already dictionary encoded before switching. A similar thing is done for the reverse case where the initial values are not dictionary encoded. ## How was this patch tested? This is difficult to test but replicated on a test cluster using a large tpcds data set. Author: Nong Li Author: Davies Liu Closes #12279 from davies/fix_dict. --- .../parquet/VectorizedColumnReader.java | 120 +++++++++++---------- .../sql/execution/vectorized/ColumnVector.java | 12 +++ 2 files changed, 78 insertions(+), 54 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 6cc2fda587..ea37a08ab5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -27,6 +27,7 @@ import org.apache.parquet.column.Encoding; import org.apache.parquet.column.page.*; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; +import org.apache.parquet.schema.PrimitiveType; import org.apache.spark.sql.execution.vectorized.ColumnVector; import org.apache.spark.sql.types.DataTypes; @@ -114,57 +115,6 @@ public class VectorizedColumnReader { } } - /** - * TODO: Hoist the useDictionary branch to decode*Batch and make the batch page aligned. 
- */ - public boolean nextBoolean() { - if (!useDictionary) { - return dataColumn.readBoolean(); - } else { - return dictionary.decodeToBoolean(dataColumn.readValueDictionaryId()); - } - } - - public int nextInt() { - if (!useDictionary) { - return dataColumn.readInteger(); - } else { - return dictionary.decodeToInt(dataColumn.readValueDictionaryId()); - } - } - - public long nextLong() { - if (!useDictionary) { - return dataColumn.readLong(); - } else { - return dictionary.decodeToLong(dataColumn.readValueDictionaryId()); - } - } - - public float nextFloat() { - if (!useDictionary) { - return dataColumn.readFloat(); - } else { - return dictionary.decodeToFloat(dataColumn.readValueDictionaryId()); - } - } - - public double nextDouble() { - if (!useDictionary) { - return dataColumn.readDouble(); - } else { - return dictionary.decodeToDouble(dataColumn.readValueDictionaryId()); - } - } - - public Binary nextBinary() { - if (!useDictionary) { - return dataColumn.readBytes(); - } else { - return dictionary.decodeToBinary(dataColumn.readValueDictionaryId()); - } - } - /** * Advances to the next value. Returns true if the value is non-null. */ @@ -200,8 +150,26 @@ public class VectorizedColumnReader { ColumnVector dictionaryIds = column.reserveDictionaryIds(total); defColumn.readIntegers( num, dictionaryIds, column, rowId, maxDefLevel, (VectorizedValuesReader) dataColumn); - decodeDictionaryIds(rowId, num, column, dictionaryIds); + + if (column.hasDictionary() || (rowId == 0 && + (descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT32 || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.INT64 || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.FLOAT || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.DOUBLE || + descriptor.getType() == PrimitiveType.PrimitiveTypeName.BINARY))) { + // Column vector supports lazy decoding of dictionary values so just set the dictionary. + // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some + // non-dictionary encoded values have already been added). + column.setDictionary(dictionary); + } else { + decodeDictionaryIds(rowId, num, column, dictionaryIds); + } } else { + if (column.hasDictionary() && rowId != 0) { + // This batch already has dictionary encoded values but this new page is not. The batch + // does not support a mix of dictionary and not so we will decode the dictionary. 
+ decodeDictionaryIds(0, rowId, column, column.getDictionaryIds()); + } column.setDictionary(null); switch (descriptor.getType()) { case BOOLEAN: @@ -246,11 +214,45 @@ public class VectorizedColumnReader { ColumnVector dictionaryIds) { switch (descriptor.getType()) { case INT32: + if (column.dataType() == DataTypes.IntegerType || + DecimalType.is32BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + column.putInt(i, dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ByteType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putByte(i, (byte) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else if (column.dataType() == DataTypes.ShortType) { + for (int i = rowId; i < rowId + num; ++i) { + column.putShort(i, (short) dictionary.decodeToInt(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + case INT64: + if (column.dataType() == DataTypes.LongType || + DecimalType.is64BitDecimalType(column.dataType())) { + for (int i = rowId; i < rowId + num; ++i) { + column.putLong(i, dictionary.decodeToLong(dictionaryIds.getInt(i))); + } + } else { + throw new NotImplementedException("Unimplemented type: " + column.dataType()); + } + break; + case FLOAT: + for (int i = rowId; i < rowId + num; ++i) { + column.putFloat(i, dictionary.decodeToFloat(dictionaryIds.getInt(i))); + } + break; + case DOUBLE: - case BINARY: - column.setDictionary(dictionary); + for (int i = rowId; i < rowId + num; ++i) { + column.putDouble(i, dictionary.decodeToDouble(dictionaryIds.getInt(i))); + } break; case INT96: if (column.dataType() == DataTypes.TimestampType) { @@ -263,6 +265,16 @@ public class VectorizedColumnReader { throw new NotImplementedException(); } break; + case BINARY: + // TODO: this is incredibly inefficient as it blows up the dictionary right here. We + // need to do this better. We should probably add the dictionary data to the ColumnVector + // and reuse it across batches. This should mean adding a ByteArray would just update + // the length and offset. + for (int i = rowId; i < rowId + num; ++i) { + Binary v = dictionary.decodeToBinary(dictionaryIds.getInt(i)); + column.putByteArray(i, v.getBytes()); + } + break; case FIXED_LEN_BYTE_ARRAY: // DecimalType written in the legacy mode if (DecimalType.is32BitDecimalType(column.dataType())) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 0b276e6c77..ff1f6680a7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -912,6 +912,11 @@ public abstract class ColumnVector implements AutoCloseable { this.dictionary = dictionary; } + /** + * Returns true if this column has a dictionary. + */ + public boolean hasDictionary() { return this.dictionary != null; } + /** * Reserve a integer column for ids of dictionary. */ @@ -926,6 +931,13 @@ public abstract class ColumnVector implements AutoCloseable { return dictionaryIds; } + /** + * Returns the underlying integer column for ids of dictionary. + */ + public ColumnVector getDictionaryIds() { + return dictionaryIds; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. 
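
To make the control flow of the fix above easier to follow, here is a minimal, self-contained sketch of the idea: dictionary-encoded pages only store ids and decode lazily, and when a plain-encoded page arrives for a batch that still holds ids, the earlier rows are materialized before the dictionary is dropped. The `IntBatch` and `IntDictionary` classes below are illustrative stand-ins, not Spark's `ColumnVector` or Parquet's `Dictionary`; only the method names mirror the real ones.

```java
import java.util.Arrays;

// Illustrative stand-ins; they only mimic the shape of the fix above.
final class IntDictionary {
  private final int[] values;
  IntDictionary(int[] values) { this.values = values; }
  int decodeToInt(int id) { return values[id]; }
}

final class IntBatch {
  final int[] data;          // decoded values for plain-encoded rows
  final int[] dictionaryIds; // ids written by dictionary-encoded pages
  private IntDictionary dictionary;

  IntBatch(int capacity) {
    data = new int[capacity];
    dictionaryIds = new int[capacity];
  }

  boolean hasDictionary() { return dictionary != null; }
  void setDictionary(IntDictionary dict) { this.dictionary = dict; }

  // Materialize rows [start, start + num) from their dictionary ids.
  void decodeDictionaryIds(int start, int num) {
    for (int i = start; i < start + num; i++) {
      data[i] = dictionary.decodeToInt(dictionaryIds[i]);
    }
  }

  // Lazy decoding: while a dictionary is attached, reads go through the ids.
  int get(int rowId) {
    return hasDictionary() ? dictionary.decodeToInt(dictionaryIds[rowId]) : data[rowId];
  }
}

public class MixedEncodingSketch {
  public static void main(String[] args) {
    IntDictionary dict = new IntDictionary(new int[] {100, 200, 300});
    IntBatch batch = new IntBatch(6);
    int rowId = 0;

    // Page 1 is dictionary encoded: only ids are stored, decoding stays lazy.
    batch.setDictionary(dict);
    for (int id : new int[] {2, 0, 1}) {
      batch.dictionaryIds[rowId++] = id;
    }

    // Page 2 is plain encoded: first materialize the rows that exist only as
    // ids, then drop the dictionary and append the plain values.
    if (batch.hasDictionary() && rowId != 0) {
      batch.decodeDictionaryIds(0, rowId);
    }
    batch.setDictionary(null);
    for (int v : new int[] {7, 8, 9}) {
      batch.data[rowId++] = v;
    }

    // Prints [300, 100, 200, 7, 8, 9]: no garbage from the earlier page.
    System.out.println(Arrays.toString(batch.data));
  }
}
```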
-- cgit v1.2.3 From 00288ea2a463180e91fd16c8e2b627e69566e1f0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 10 Apr 2016 02:34:54 +0100 Subject: [SPARK-13687][PYTHON] Cleanup PySpark parallelize temporary files ## What changes were proposed in this pull request? Eagerly cleanup PySpark's temporary parallelize cleanup files rather than waiting for shut down. ## How was this patch tested? Unit tests Author: Holden Karau Closes #12233 from holdenk/SPARK-13687-cleanup-pyspark-temporary-files. --- python/pyspark/context.py | 22 +++++++++++++--------- python/pyspark/tests.py | 7 +++++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 529d16b480..cb15b4b91f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -428,15 +428,19 @@ class SparkContext(object): # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - # Make sure we distribute data evenly if it's smaller than self.batchSize - if "__len__" not in dir(c): - c = list(c) # Make it a list so we can compute its length - batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) - serializer = BatchedSerializer(self._unbatched_serializer, batchSize) - serializer.dump_stream(c, tempFile) - tempFile.close() - readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile - jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + try: + # Make sure we distribute data evenly if it's smaller than self.batchSize + if "__len__" not in dir(c): + c = list(c) # Make it a list so we can compute its length + batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) + serializer.dump_stream(c, tempFile) + tempFile.close() + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) + finally: + # readRDDFromFile eagerily reads the file so we can delete right after. + os.unlink(tempFile.name) return RDD(jrdd, self, serializer) def pickleFile(self, name, minPartitions=None): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 15c87e22f9..97ea39dde0 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1914,6 +1914,13 @@ class ContextTests(unittest.TestCase): with SparkContext.getOrCreate() as sc: self.assertTrue(SparkContext.getOrCreate() is sc) + def test_parallelize_eager_cleanup(self): + with SparkContext() as sc: + temp_files = os.listdir(sc._temp_dir) + rdd = sc.parallelize([0, 1, 2]) + post_parallalize_temp_files = os.listdir(sc._temp_dir) + self.assertEqual(temp_files, post_parallalize_temp_files) + def test_stop(self): sc = SparkContext() self.assertNotEqual(SparkContext._active_spark_context, None) -- cgit v1.2.3 From 72e66bb270efa3dc55560a4b2657e065cfdf2ea5 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sun, 10 Apr 2016 02:37:11 +0100 Subject: [SPARK-14301][EXAMPLES] Java examples code merge and clean up. ## What changes were proposed in this pull request? This fix tries to remove duplicate Java code in examples/mllib and examples/ml. 
The following changes have been made: ``` deleted: ml/JavaCrossValidatorExample.java (duplicate of JavaModelSelectionViaCrossValidationExample.java) deleted: ml/JavaTrainValidationSplitExample.java (duplicated of JavaModelSelectionViaTrainValidationSplitExample.java) deleted: mllib/JavaFPGrowthExample.java (duplicate of JavaSimpleFPGrowth.java) deleted: mllib/JavaLDAExample.java (duplicate of JavaLatentDirichletAllocationExample.java) deleted: mllib/JavaKMeans.java (merged with JavaKMeansExample.java) deleted: mllib/JavaLR.java (duplicate of JavaLinearRegressionWithSGDExample.java) updated: mllib/JavaKMeansExample.java (merged with mllib/JavaKMeans.java) ``` ## How was this patch tested? Existing tests passed. Author: Yong Tang Closes #12143 from yongtang/SPARK-14301. --- .../examples/ml/JavaCrossValidatorExample.java | 127 --------------------- ...delSelectionViaTrainValidationSplitExample.java | 10 +- .../ml/JavaTrainValidationSplitExample.java | 87 -------------- .../spark/examples/mllib/JavaFPGrowthExample.java | 78 ------------- .../apache/spark/examples/mllib/JavaKMeans.java | 82 ------------- .../spark/examples/mllib/JavaKMeansExample.java | 7 ++ .../spark/examples/mllib/JavaLDAExample.java | 77 ------------- .../org/apache/spark/examples/mllib/JavaLR.java | 82 ------------- 8 files changed, 16 insertions(+), 534 deletions(-) delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java delete mode 100644 examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java deleted file mode 100644 index 07edeb3e52..0000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.ml; - -import java.util.List; - -import com.google.common.collect.Lists; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.Pipeline; -import org.apache.spark.ml.PipelineStage; -import org.apache.spark.ml.classification.LogisticRegression; -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; -import org.apache.spark.ml.feature.HashingTF; -import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.tuning.CrossValidator; -import org.apache.spark.ml.tuning.CrossValidatorModel; -import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - -/** - * A simple example demonstrating model selection using CrossValidator. - * This example also demonstrates how Pipelines are Estimators. - * - * This example uses the Java bean classes {@link org.apache.spark.examples.ml.LabeledDocument} and - * {@link org.apache.spark.examples.ml.Document} defined in the Scala example - * {@link org.apache.spark.examples.ml.SimpleTextClassificationPipeline}. - * - * Run with - *
    - * bin/run-example ml.JavaCrossValidatorExample
    - * 
    - */ -public class JavaCrossValidatorExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - // Prepare training documents, which are labeled. - List localTraining = Lists.newArrayList( - new LabeledDocument(0L, "a b c d e spark", 1.0), - new LabeledDocument(1L, "b d", 0.0), - new LabeledDocument(2L, "spark f g h", 1.0), - new LabeledDocument(3L, "hadoop mapreduce", 0.0), - new LabeledDocument(4L, "b spark who", 1.0), - new LabeledDocument(5L, "g d a y", 0.0), - new LabeledDocument(6L, "spark fly", 1.0), - new LabeledDocument(7L, "was mapreduce", 0.0), - new LabeledDocument(8L, "e spark program", 1.0), - new LabeledDocument(9L, "a e c l", 0.0), - new LabeledDocument(10L, "spark compile", 1.0), - new LabeledDocument(11L, "hadoop software", 0.0)); - Dataset training = jsql.createDataFrame( - jsc.parallelize(localTraining), LabeledDocument.class); - - // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. - Tokenizer tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words"); - HashingTF hashingTF = new HashingTF() - .setNumFeatures(1000) - .setInputCol(tokenizer.getOutputCol()) - .setOutputCol("features"); - LogisticRegression lr = new LogisticRegression() - .setMaxIter(10) - .setRegParam(0.01); - Pipeline pipeline = new Pipeline() - .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); - - // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. - // This will allow us to jointly choose parameters for all Pipeline stages. - // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. - CrossValidator crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator()); - // We use a ParamGridBuilder to construct a grid of parameters to search over. - // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, - // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. - ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures(), new int[]{10, 100, 1000}) - .addGrid(lr.regParam(), new double[]{0.1, 0.01}) - .build(); - crossval.setEstimatorParamMaps(paramGrid); - crossval.setNumFolds(2); // Use 3+ in practice - - // Run cross-validation, and choose the best set of parameters. - CrossValidatorModel cvModel = crossval.fit(training); - - // Prepare test documents, which are unlabeled. - List localTest = Lists.newArrayList( - new Document(4L, "spark i j k"), - new Document(5L, "l m n"), - new Document(6L, "mapreduce spark"), - new Document(7L, "apache hadoop")); - Dataset test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class); - - // Make predictions on test documents. cvModel uses the best model found (lrModel). 
- Dataset predictions = cvModel.transform(test); - for (Row r: predictions.select("id", "text", "probability", "prediction").collectAsList()) { - System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2) - + ", prediction=" + r.get(3)); - } - - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java index 6ac4aea3c4..4994f8f9fa 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaModelSelectionViaTrainValidationSplitExample.java @@ -32,7 +32,15 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; /** - * Java example for Model Selection via Train Validation Split. + * Java example demonstrating model selection using TrainValidationSplit. + * + * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} + * using linear regression. + * + * Run with + * {{{ + * bin/run-example ml.JavaModelSelectionViaTrainValidationSplitExample + * }}} */ public class JavaModelSelectionViaTrainValidationSplitExample { public static void main(String[] args) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java deleted file mode 100644 index 09bbc39c01..0000000000 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTrainValidationSplitExample.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.ml.evaluation.RegressionEvaluator; -import org.apache.spark.ml.param.ParamMap; -import org.apache.spark.ml.regression.LinearRegression; -import org.apache.spark.ml.tuning.*; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; - -/** - * A simple example demonstrating model selection using TrainValidationSplit. - * - * The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample} - * using linear regression. 
- * - * Run with - * {{{ - * bin/run-example ml.JavaTrainValidationSplitExample - * }}} - */ -public class JavaTrainValidationSplitExample { - - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample"); - JavaSparkContext jsc = new JavaSparkContext(conf); - SQLContext jsql = new SQLContext(jsc); - - Dataset data = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); - - // Prepare training and test data. - Dataset[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345); - Dataset training = splits[0]; - Dataset test = splits[1]; - - LinearRegression lr = new LinearRegression(); - - // We use a ParamGridBuilder to construct a grid of parameters to search over. - // TrainValidationSplit will try all combinations of values and determine best model using - // the evaluator. - ParamMap[] paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam(), new double[] {0.1, 0.01}) - .addGrid(lr.fitIntercept()) - .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0}) - .build(); - - // In this case the estimator is simply the linear regression. - // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. - TrainValidationSplit trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator()) - .setEstimatorParamMaps(paramGrid); - - // 80% of the data will be used for training and the remaining 20% for validation. - trainValidationSplit.setTrainRatio(0.8); - - // Run train validation split, and choose the best set of parameters. - TrainValidationSplitModel model = trainValidationSplit.fit(training); - - // Make predictions on test data. model is the model with combination of parameters - // that performed best. - model.transform(test) - .select("features", "label", "prediction") - .show(); - - jsc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java deleted file mode 100644 index 36baf58687..0000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -import java.util.ArrayList; - -import com.google.common.base.Joiner; -import com.google.common.collect.Lists; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.fpm.FPGrowth; -import org.apache.spark.mllib.fpm.FPGrowthModel; - -/** - * Java example for mining frequent itemsets using FP-growth. - * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt - */ -public class JavaFPGrowthExample { - - public static void main(String[] args) { - String inputFile; - double minSupport = 0.3; - int numPartition = -1; - if (args.length < 1) { - System.err.println( - "Usage: JavaFPGrowth [minSupport] [numPartition]"); - System.exit(1); - } - inputFile = args[0]; - if (args.length >= 2) { - minSupport = Double.parseDouble(args[1]); - } - if (args.length >= 3) { - numPartition = Integer.parseInt(args[2]); - } - - SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - - JavaRDD> transactions = sc.textFile(inputFile).map( - new Function>() { - @Override - public ArrayList call(String s) { - return Lists.newArrayList(s.split(" ")); - } - } - ); - - FPGrowthModel model = new FPGrowth() - .setMinSupport(minSupport) - .setNumPartitions(numPartition) - .run(transactions); - - for (FPGrowth.FreqItemset s: model.freqItemsets().toJavaRDD().collect()) { - System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq()); - } - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java deleted file mode 100644 index e575eedeb4..0000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; - -import org.apache.spark.mllib.clustering.KMeans; -import org.apache.spark.mllib.clustering.KMeansModel; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; - -/** - * Example using MLlib KMeans from Java. 
- */ -public final class JavaKMeans { - - private static class ParsePoint implements Function { - private static final Pattern SPACE = Pattern.compile(" "); - - @Override - public Vector call(String line) { - String[] tok = SPACE.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - return Vectors.dense(point); - } - } - - public static void main(String[] args) { - if (args.length < 3) { - System.err.println( - "Usage: JavaKMeans []"); - System.exit(1); - } - String inputFile = args[0]; - int k = Integer.parseInt(args[1]); - int iterations = Integer.parseInt(args[2]); - int runs = 1; - - if (args.length >= 4) { - runs = Integer.parseInt(args[3]); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaKMeans"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - JavaRDD lines = sc.textFile(inputFile); - - JavaRDD points = lines.map(new ParsePoint()); - - KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL()); - - System.out.println("Cluster centers:"); - for (Vector center : model.clusterCenters()) { - System.out.println(" " + center); - } - double cost = model.computeCost(points.rdd()); - System.out.println("Cost: " + cost); - - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java index 006d96d111..2d89c768fc 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeansExample.java @@ -58,6 +58,13 @@ public class JavaKMeansExample { int numIterations = 20; KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations); + System.out.println("Cluster centers:"); + for (Vector center: clusters.clusterCenters()) { + System.out.println(" " + center); + } + double cost = clusters.computeCost(parsedData.rdd()); + System.out.println("Cost: " + cost); + // Evaluate clustering by computing Within Set Sum of Squared Errors double WSSSE = clusters.computeCost(parsedData.rdd()); System.out.println("Within Set Sum of Squared Errors = " + WSSSE); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java deleted file mode 100644 index de8e739ac9..0000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -import scala.Tuple2; - -import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.mllib.clustering.DistributedLDAModel; -import org.apache.spark.mllib.clustering.LDA; -import org.apache.spark.mllib.linalg.Matrix; -import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.SparkConf; - -public class JavaLDAExample { - public static void main(String[] args) { - SparkConf conf = new SparkConf().setAppName("LDA Example"); - JavaSparkContext sc = new JavaSparkContext(conf); - - // Load and parse the data - String path = "data/mllib/sample_lda_data.txt"; - JavaRDD data = sc.textFile(path); - JavaRDD parsedData = data.map( - new Function() { - public Vector call(String s) { - String[] sarray = s.trim().split(" "); - double[] values = new double[sarray.length]; - for (int i = 0; i < sarray.length; i++) { - values[i] = Double.parseDouble(sarray[i]); - } - return Vectors.dense(values); - } - } - ); - // Index documents with unique IDs - JavaPairRDD corpus = JavaPairRDD.fromJavaRDD(parsedData.zipWithIndex().map( - new Function, Tuple2>() { - public Tuple2 call(Tuple2 doc_id) { - return doc_id.swap(); - } - } - )); - corpus.cache(); - - // Cluster the documents into three topics using LDA - DistributedLDAModel ldaModel = (DistributedLDAModel)new LDA().setK(3).run(corpus); - - // Output topics. Each is a distribution over words (matching word count vectors) - System.out.println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize() - + " words):"); - Matrix topics = ldaModel.topicsMatrix(); - for (int topic = 0; topic < 3; topic++) { - System.out.print("Topic " + topic + ":"); - for (int word = 0; word < ldaModel.vocabSize(); word++) { - System.out.print(" " + topics.apply(word, topic)); - } - System.out.println(); - } - sc.stop(); - } -} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java deleted file mode 100644 index eceb6927d5..0000000000 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLR.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.examples.mllib; - -import java.util.regex.Pattern; - -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; - -import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; -import org.apache.spark.mllib.classification.LogisticRegressionModel; -import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.mllib.regression.LabeledPoint; - -/** - * Logistic regression based classification using ML Lib. - */ -public final class JavaLR { - - static class ParsePoint implements Function { - private static final Pattern COMMA = Pattern.compile(","); - private static final Pattern SPACE = Pattern.compile(" "); - - @Override - public LabeledPoint call(String line) { - String[] parts = COMMA.split(line); - double y = Double.parseDouble(parts[0]); - String[] tok = SPACE.split(parts[1]); - double[] x = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - x[i] = Double.parseDouble(tok[i]); - } - return new LabeledPoint(y, Vectors.dense(x)); - } - } - - public static void main(String[] args) { - if (args.length != 3) { - System.err.println("Usage: JavaLR "); - System.exit(1); - } - SparkConf sparkConf = new SparkConf().setAppName("JavaLR"); - JavaSparkContext sc = new JavaSparkContext(sparkConf); - JavaRDD lines = sc.textFile(args[0]); - JavaRDD points = lines.map(new ParsePoint()).cache(); - double stepSize = Double.parseDouble(args[1]); - int iterations = Integer.parseInt(args[2]); - - // Another way to configure LogisticRegression - // - // LogisticRegressionWithSGD lr = new LogisticRegressionWithSGD(); - // lr.optimizer().setNumIterations(iterations) - // .setStepSize(stepSize) - // .setMiniBatchFraction(1.0); - // lr.setIntercept(true); - // LogisticRegressionModel model = lr.train(points.rdd()); - - LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), - iterations, stepSize); - - System.out.print("Final w: " + model.weights()); - - sc.stop(); - } -} -- cgit v1.2.3 From aea30a1a9b79eb13d362ef32e4e9c8233e29f3dc Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 9 Apr 2016 21:31:20 -0700 Subject: [SPARK-14465][BUILD] Checkstyle should check all Java files ## What changes were proposed in this pull request? Currently, `checkstyle` is configured to check the files under `src/main/java`. However, Spark has Java files in `src/main/scala`, too. This PR fixes the following configuration in `pom.xml` and the unchecked-so-far violations on those files. ```xml -${basedir}/src/main/java +${basedir}/src/main/java,${basedir}/src/main/scala ``` ## How was this patch tested? After passing the Jenkins build and manually `dev/lint-java`. (Note that Jenkins does not run `lint-java`) Author: Dongjoon Hyun Closes #12242 from dongjoon-hyun/SPARK-14465. 
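
Since the main file touched by this commit is the relocated `org.apache.spark.io.LZ4BlockInputStream`, a brief usage sketch may help when reading the diff below. It is based only on the class's own documentation (it decodes data written with `net.jpountz.lz4.LZ4BlockOutputStream`) and assumes `lz4-java` and the Spark class are on the classpath; the class name `Lz4BlockRoundTrip` is just for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import net.jpountz.lz4.LZ4BlockOutputStream;
import org.apache.spark.io.LZ4BlockInputStream;

public class Lz4BlockRoundTrip {
  public static void main(String[] args) throws IOException {
    byte[] original = "hello lz4 block stream".getBytes("UTF-8");

    // Write with lz4-java's block writer; this is the format the class below documents.
    ByteArrayOutputStream compressed = new ByteArrayOutputStream();
    try (OutputStream out = new LZ4BlockOutputStream(compressed)) {
      out.write(original);
    }

    // Read it back with the relocated Spark input stream.
    ByteArrayOutputStream decoded = new ByteArrayOutputStream();
    try (InputStream in =
        new LZ4BlockInputStream(new ByteArrayInputStream(compressed.toByteArray()))) {
      byte[] buf = new byte[1024];
      int n;
      while ((n = in.read(buf)) != -1) {
        decoded.write(buf, 0, n);
      }
    }

    // Prints the original text if the round trip succeeded.
    System.out.println(new String(decoded.toByteArray(), "UTF-8"));
  }
}
```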
--- .../org/apache/spark/io/LZ4BlockInputStream.java | 261 ++++++++++++++++++++ .../org/apache/spark/io/LZ4BlockInputStream.java | 263 --------------------- .../org/apache/spark/streaming/Java8APISuite.java | 6 +- pom.xml | 2 +- .../spark/sql/execution/BufferedRowIterator.java | 94 ++++++++ .../apache/spark/sql/expressions/java/typed.java | 75 ++++++ .../spark/sql/execution/BufferedRowIterator.java | 94 -------- .../apache/spark/sql/expressions/java/typed.java | 76 ------ 8 files changed, 435 insertions(+), 436 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java delete mode 100644 core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java diff --git a/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java new file mode 100644 index 0000000000..8783b5f56e --- /dev/null +++ b/core/src/main/java/org/apache/spark/io/LZ4BlockInputStream.java @@ -0,0 +1,261 @@ +package org.apache.spark.io; + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.EOFException; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.zip.Checksum; + +import net.jpountz.lz4.LZ4Exception; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4FastDecompressor; +import net.jpountz.util.SafeUtils; +import net.jpountz.xxhash.XXHashFactory; + +/** + * {@link InputStream} implementation to decode data written with + * {@link net.jpountz.lz4.LZ4BlockOutputStream}. This class is not thread-safe and does not + * support {@link #mark(int)}/{@link #reset()}. 
+ * @see net.jpountz.lz4.LZ4BlockOutputStream + * + * This is based on net.jpountz.lz4.LZ4BlockInputStream + * + * changes: https://github.com/davies/lz4-java/commit/cc1fa940ac57cc66a0b937300f805d37e2bf8411 + * + * TODO: merge this into upstream + */ +public final class LZ4BlockInputStream extends FilterInputStream { + + // Copied from net.jpountz.lz4.LZ4BlockOutputStream + static final byte[] MAGIC = new byte[] { 'L', 'Z', '4', 'B', 'l', 'o', 'c', 'k' }; + static final int MAGIC_LENGTH = MAGIC.length; + + static final int HEADER_LENGTH = + MAGIC_LENGTH // magic bytes + + 1 // token + + 4 // compressed length + + 4 // decompressed length + + 4; // checksum + + static final int COMPRESSION_LEVEL_BASE = 10; + + static final int COMPRESSION_METHOD_RAW = 0x10; + static final int COMPRESSION_METHOD_LZ4 = 0x20; + + static final int DEFAULT_SEED = 0x9747b28c; + + private final LZ4FastDecompressor decompressor; + private final Checksum checksum; + private byte[] buffer; + private byte[] compressedBuffer; + private int originalLen; + private int o; + private boolean finished; + + /** + * Create a new {@link InputStream}. + * + * @param in the {@link InputStream} to poll + * @param decompressor the {@link LZ4FastDecompressor decompressor} instance to + * use + * @param checksum the {@link Checksum} instance to use, must be + * equivalent to the instance which has been used to + * write the stream + */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum) { + super(in); + this.decompressor = decompressor; + this.checksum = checksum; + this.buffer = new byte[0]; + this.compressedBuffer = new byte[HEADER_LENGTH]; + o = originalLen = 0; + finished = false; + } + + /** + * Create a new instance using {@link net.jpountz.xxhash.XXHash32} for checksuming. + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum) + * @see net.jpountz.xxhash.StreamingXXHash32#asChecksum() + */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) { + this(in, decompressor, + XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); + } + + /** + * Create a new instance which uses the fastest {@link LZ4FastDecompressor} available. 
+ * @see LZ4Factory#fastestInstance() + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor) + */ + public LZ4BlockInputStream(InputStream in) { + this(in, LZ4Factory.fastestInstance().fastDecompressor()); + } + + @Override + public int available() throws IOException { + refill(); + return originalLen - o; + } + + @Override + public int read() throws IOException { + refill(); + if (finished) { + return -1; + } + return buffer[o++] & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + SafeUtils.checkRange(b, off, len); + refill(); + if (finished) { + return -1; + } + len = Math.min(len, originalLen - o); + System.arraycopy(buffer, o, b, off, len); + o += len; + return len; + } + + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + + @Override + public long skip(long n) throws IOException { + refill(); + if (finished) { + return -1; + } + final int skipped = (int) Math.min(n, originalLen - o); + o += skipped; + return skipped; + } + + private void refill() throws IOException { + if (finished || o < originalLen) { + return; + } + try { + readFully(compressedBuffer, HEADER_LENGTH); + } catch (EOFException e) { + finished = true; + return; + } + for (int i = 0; i < MAGIC_LENGTH; ++i) { + if (compressedBuffer[i] != MAGIC[i]) { + throw new IOException("Stream is corrupted"); + } + } + final int token = compressedBuffer[MAGIC_LENGTH] & 0xFF; + final int compressionMethod = token & 0xF0; + final int compressionLevel = COMPRESSION_LEVEL_BASE + (token & 0x0F); + if (compressionMethod != COMPRESSION_METHOD_RAW && compressionMethod != COMPRESSION_METHOD_LZ4) + { + throw new IOException("Stream is corrupted"); + } + final int compressedLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 1); + originalLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 5); + final int check = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 9); + assert HEADER_LENGTH == MAGIC_LENGTH + 13; + if (originalLen > 1 << compressionLevel + || originalLen < 0 + || compressedLen < 0 + || (originalLen == 0 && compressedLen != 0) + || (originalLen != 0 && compressedLen == 0) + || (compressionMethod == COMPRESSION_METHOD_RAW && originalLen != compressedLen)) { + throw new IOException("Stream is corrupted"); + } + if (originalLen == 0 && compressedLen == 0) { + if (check != 0) { + throw new IOException("Stream is corrupted"); + } + refill(); + return; + } + if (buffer.length < originalLen) { + buffer = new byte[Math.max(originalLen, buffer.length * 3 / 2)]; + } + switch (compressionMethod) { + case COMPRESSION_METHOD_RAW: + readFully(buffer, originalLen); + break; + case COMPRESSION_METHOD_LZ4: + if (compressedBuffer.length < originalLen) { + compressedBuffer = new byte[Math.max(compressedLen, compressedBuffer.length * 3 / 2)]; + } + readFully(compressedBuffer, compressedLen); + try { + final int compressedLen2 = + decompressor.decompress(compressedBuffer, 0, buffer, 0, originalLen); + if (compressedLen != compressedLen2) { + throw new IOException("Stream is corrupted"); + } + } catch (LZ4Exception e) { + throw new IOException("Stream is corrupted", e); + } + break; + default: + throw new AssertionError(); + } + checksum.reset(); + checksum.update(buffer, 0, originalLen); + if ((int) checksum.getValue() != check) { + throw new IOException("Stream is corrupted"); + } + o = 0; + } + + private void readFully(byte[] b, int len) throws IOException { + int read = 0; + while (read < len) { + final int r = in.read(b, read, len - 
read); + if (r < 0) { + throw new EOFException("Stream ended prematurely"); + } + read += r; + } + assert len == read; + } + + @Override + public boolean markSupported() { + return false; + } + + @SuppressWarnings("sync-override") + @Override + public void mark(int readlimit) { + // unsupported + } + + @SuppressWarnings("sync-override") + @Override + public void reset() throws IOException { + throw new IOException("mark/reset not supported"); + } + + @Override + public String toString() { + return getClass().getSimpleName() + "(in=" + in + + ", decompressor=" + decompressor + ", checksum=" + checksum + ")"; + } + +} diff --git a/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java b/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java deleted file mode 100644 index 27b6f0d4a3..0000000000 --- a/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java +++ /dev/null @@ -1,263 +0,0 @@ -package org.apache.spark.io; - -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import java.io.EOFException; -import java.io.FilterInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.util.zip.Checksum; - -import net.jpountz.lz4.LZ4BlockOutputStream; -import net.jpountz.lz4.LZ4Exception; -import net.jpountz.lz4.LZ4Factory; -import net.jpountz.lz4.LZ4FastDecompressor; -import net.jpountz.util.SafeUtils; -import net.jpountz.xxhash.StreamingXXHash32; -import net.jpountz.xxhash.XXHash32; -import net.jpountz.xxhash.XXHashFactory; - -/** - * {@link InputStream} implementation to decode data written with - * {@link LZ4BlockOutputStream}. This class is not thread-safe and does not - * support {@link #mark(int)}/{@link #reset()}. - * @see LZ4BlockOutputStream - * - * This is based on net.jpountz.lz4.LZ4BlockInputStream - * - * changes: https://github.com/davies/lz4-java/commit/cc1fa940ac57cc66a0b937300f805d37e2bf8411 - * - * TODO: merge this into upstream - */ -public final class LZ4BlockInputStream extends FilterInputStream { - - // Copied from net.jpountz.lz4.LZ4BlockOutputStream - static final byte[] MAGIC = new byte[] { 'L', 'Z', '4', 'B', 'l', 'o', 'c', 'k' }; - static final int MAGIC_LENGTH = MAGIC.length; - - static final int HEADER_LENGTH = - MAGIC_LENGTH // magic bytes - + 1 // token - + 4 // compressed length - + 4 // decompressed length - + 4; // checksum - - static final int COMPRESSION_LEVEL_BASE = 10; - - static final int COMPRESSION_METHOD_RAW = 0x10; - static final int COMPRESSION_METHOD_LZ4 = 0x20; - - static final int DEFAULT_SEED = 0x9747b28c; - - private final LZ4FastDecompressor decompressor; - private final Checksum checksum; - private byte[] buffer; - private byte[] compressedBuffer; - private int originalLen; - private int o; - private boolean finished; - - /** - * Create a new {@link InputStream}. 
- * - * @param in the {@link InputStream} to poll - * @param decompressor the {@link LZ4FastDecompressor decompressor} instance to - * use - * @param checksum the {@link Checksum} instance to use, must be - * equivalent to the instance which has been used to - * write the stream - */ - public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum) { - super(in); - this.decompressor = decompressor; - this.checksum = checksum; - this.buffer = new byte[0]; - this.compressedBuffer = new byte[HEADER_LENGTH]; - o = originalLen = 0; - finished = false; - } - - /** - * Create a new instance using {@link XXHash32} for checksuming. - * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum) - * @see StreamingXXHash32#asChecksum() - */ - public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) { - this(in, decompressor, XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); - } - - /** - * Create a new instance which uses the fastest {@link LZ4FastDecompressor} available. - * @see LZ4Factory#fastestInstance() - * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor) - */ - public LZ4BlockInputStream(InputStream in) { - this(in, LZ4Factory.fastestInstance().fastDecompressor()); - } - - @Override - public int available() throws IOException { - refill(); - return originalLen - o; - } - - @Override - public int read() throws IOException { - refill(); - if (finished) { - return -1; - } - return buffer[o++] & 0xFF; - } - - @Override - public int read(byte[] b, int off, int len) throws IOException { - SafeUtils.checkRange(b, off, len); - refill(); - if (finished) { - return -1; - } - len = Math.min(len, originalLen - o); - System.arraycopy(buffer, o, b, off, len); - o += len; - return len; - } - - @Override - public int read(byte[] b) throws IOException { - return read(b, 0, b.length); - } - - @Override - public long skip(long n) throws IOException { - refill(); - if (finished) { - return -1; - } - final int skipped = (int) Math.min(n, originalLen - o); - o += skipped; - return skipped; - } - - private void refill() throws IOException { - if (finished || o < originalLen) { - return; - } - try { - readFully(compressedBuffer, HEADER_LENGTH); - } catch (EOFException e) { - finished = true; - return; - } - for (int i = 0; i < MAGIC_LENGTH; ++i) { - if (compressedBuffer[i] != MAGIC[i]) { - throw new IOException("Stream is corrupted"); - } - } - final int token = compressedBuffer[MAGIC_LENGTH] & 0xFF; - final int compressionMethod = token & 0xF0; - final int compressionLevel = COMPRESSION_LEVEL_BASE + (token & 0x0F); - if (compressionMethod != COMPRESSION_METHOD_RAW && compressionMethod != COMPRESSION_METHOD_LZ4) - { - throw new IOException("Stream is corrupted"); - } - final int compressedLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 1); - originalLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 5); - final int check = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 9); - assert HEADER_LENGTH == MAGIC_LENGTH + 13; - if (originalLen > 1 << compressionLevel - || originalLen < 0 - || compressedLen < 0 - || (originalLen == 0 && compressedLen != 0) - || (originalLen != 0 && compressedLen == 0) - || (compressionMethod == COMPRESSION_METHOD_RAW && originalLen != compressedLen)) { - throw new IOException("Stream is corrupted"); - } - if (originalLen == 0 && compressedLen == 0) { - if (check != 0) { - throw new IOException("Stream is corrupted"); - } - refill(); - return; - } - if (buffer.length < 
originalLen) { - buffer = new byte[Math.max(originalLen, buffer.length * 3 / 2)]; - } - switch (compressionMethod) { - case COMPRESSION_METHOD_RAW: - readFully(buffer, originalLen); - break; - case COMPRESSION_METHOD_LZ4: - if (compressedBuffer.length < originalLen) { - compressedBuffer = new byte[Math.max(compressedLen, compressedBuffer.length * 3 / 2)]; - } - readFully(compressedBuffer, compressedLen); - try { - final int compressedLen2 = - decompressor.decompress(compressedBuffer, 0, buffer, 0, originalLen); - if (compressedLen != compressedLen2) { - throw new IOException("Stream is corrupted"); - } - } catch (LZ4Exception e) { - throw new IOException("Stream is corrupted", e); - } - break; - default: - throw new AssertionError(); - } - checksum.reset(); - checksum.update(buffer, 0, originalLen); - if ((int) checksum.getValue() != check) { - throw new IOException("Stream is corrupted"); - } - o = 0; - } - - private void readFully(byte[] b, int len) throws IOException { - int read = 0; - while (read < len) { - final int r = in.read(b, read, len - read); - if (r < 0) { - throw new EOFException("Stream ended prematurely"); - } - read += r; - } - assert len == read; - } - - @Override - public boolean markSupported() { - return false; - } - - @SuppressWarnings("sync-override") - @Override - public void mark(int readlimit) { - // unsupported - } - - @SuppressWarnings("sync-override") - @Override - public void reset() throws IOException { - throw new IOException("mark/reset not supported"); - } - - @Override - public String toString() { - return getClass().getSimpleName() + "(in=" + in - + ", decompressor=" + decompressor + ", checksum=" + checksum + ")"; - } - -} diff --git a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 67bc64a444..d0fed303e6 100644 --- a/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/external/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -377,7 +377,9 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ }); // This is a test to make sure foreachRDD(VoidFunction2) can be called from Java - stream.foreachRDD((rdd, time) -> { return; }); + stream.foreachRDD((rdd, time) -> { + return; + }); JavaTestUtils.runStreams(ssc, 2, 2); @@ -873,7 +875,7 @@ public class Java8APISuite extends LocalJavaStreamingContext implements Serializ JavaMapWithStateDStream stateDstream = wordsDstream.mapWithState( - StateSpec. function((time, key, value, state) -> { + StateSpec.function((time, key, value, state) -> { // Use all State's methods here state.exists(); state.get(); diff --git a/pom.xml b/pom.xml index 3f9e4abc32..4cbc6a2f11 100644 --- a/pom.xml +++ b/pom.xml @@ -2253,7 +2253,7 @@ false true false - ${basedir}/src/main/java + ${basedir}/src/main/java,${basedir}/src/main/scala ${basedir}/src/test/java dev/checkstyle.xml ${basedir}/target/checkstyle-output.xml diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java new file mode 100644 index 0000000000..086547c793 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution; + +import java.io.IOException; +import java.util.LinkedList; + +import scala.collection.Iterator; + +import org.apache.spark.TaskContext; +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; + +/** + * An iterator interface used to pull the output from generated function for multiple operators + * (whole stage codegen). + */ +public abstract class BufferedRowIterator { + protected LinkedList currentRows = new LinkedList<>(); + // used when there is no column in output + protected UnsafeRow unsafeRow = new UnsafeRow(0); + private long startTimeNs = System.nanoTime(); + + protected int partitionIndex = -1; + + public boolean hasNext() throws IOException { + if (currentRows.isEmpty()) { + processNext(); + } + return !currentRows.isEmpty(); + } + + public InternalRow next() { + return currentRows.remove(); + } + + /** + * Returns the elapsed time since this object is created. This object represents a pipeline so + * this is a measure of how long the pipeline has been running. + */ + public long durationMs() { + return (System.nanoTime() - startTimeNs) / (1000 * 1000); + } + + /** + * Initializes from array of iterators of InternalRow. + */ + public abstract void init(int index, Iterator[] iters); + + /** + * Append a row to currentRows. + */ + protected void append(InternalRow row) { + currentRows.add(row); + } + + /** + * Returns whether `processNext()` should stop processing next row from `input` or not. + * + * If it returns true, the caller should exit the loop (return from processNext()). + */ + protected boolean shouldStop() { + return !currentRows.isEmpty(); + } + + /** + * Increase the peak execution memory for current task. + */ + protected void incPeakExecutionMemory(long size) { + TaskContext.get().taskMetrics().incPeakExecutionMemory(size); + } + + /** + * Processes the input until have a row as output (currentRow). + * + * After it's called, if currentRow is still null, it means no more rows left. + */ + protected abstract void processNext() throws IOException; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java new file mode 100644 index 0000000000..c7c6e3868f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/expressions/java/typed.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.expressions.java; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.api.java.function.MapFunction; +import org.apache.spark.sql.TypedColumn; +import org.apache.spark.sql.execution.aggregate.TypedAverage; +import org.apache.spark.sql.execution.aggregate.TypedCount; +import org.apache.spark.sql.execution.aggregate.TypedSumDouble; +import org.apache.spark.sql.execution.aggregate.TypedSumLong; + +/** + * :: Experimental :: + * Type-safe functions available for {@link org.apache.spark.sql.Dataset} operations in Java. + * + * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}. + * + * @since 2.0.0 + */ +@Experimental +public class typed { + // Note: make sure to keep in sync with typed.scala + + /** + * Average aggregate function. + * + * @since 2.0.0 + */ + public static TypedColumn avg(MapFunction f) { + return new TypedAverage(f).toColumnJava(); + } + + /** + * Count aggregate function. + * + * @since 2.0.0 + */ + public static TypedColumn count(MapFunction f) { + return new TypedCount(f).toColumnJava(); + } + + /** + * Sum aggregate function for floating point (double) type. + * + * @since 2.0.0 + */ + public static TypedColumn sum(MapFunction f) { + return new TypedSumDouble(f).toColumnJava(); + } + + /** + * Sum aggregate function for integral (long, i.e. 64 bit integer) type. + * + * @since 2.0.0 + */ + public static TypedColumn sumLong(MapFunction f) { + return new TypedSumLong(f).toColumnJava(); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java deleted file mode 100644 index c2633a9f8c..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution; - -import java.io.IOException; -import java.util.LinkedList; - -import scala.collection.Iterator; - -import org.apache.spark.TaskContext; -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; - -/** - * An iterator interface used to pull the output from generated function for multiple operators - * (whole stage codegen). 
- */ -public abstract class BufferedRowIterator { - protected LinkedList currentRows = new LinkedList<>(); - // used when there is no column in output - protected UnsafeRow unsafeRow = new UnsafeRow(0); - private long startTimeNs = System.nanoTime(); - - protected int partitionIndex = -1; - - public boolean hasNext() throws IOException { - if (currentRows.isEmpty()) { - processNext(); - } - return !currentRows.isEmpty(); - } - - public InternalRow next() { - return currentRows.remove(); - } - - /** - * Returns the elapsed time since this object is created. This object represents a pipeline so - * this is a measure of how long the pipeline has been running. - */ - public long durationMs() { - return (System.nanoTime() - startTimeNs) / (1000 * 1000); - } - - /** - * Initializes from array of iterators of InternalRow. - */ - public abstract void init(int index, Iterator iters[]); - - /** - * Append a row to currentRows. - */ - protected void append(InternalRow row) { - currentRows.add(row); - } - - /** - * Returns whether `processNext()` should stop processing next row from `input` or not. - * - * If it returns true, the caller should exit the loop (return from processNext()). - */ - protected boolean shouldStop() { - return !currentRows.isEmpty(); - } - - /** - * Increase the peak execution memory for current task. - */ - protected void incPeakExecutionMemory(long size) { - TaskContext.get().taskMetrics().incPeakExecutionMemory(size); - } - - /** - * Processes the input until have a row as output (currentRow). - * - * After it's called, if currentRow is still null, it means no more rows left. - */ - protected abstract void processNext() throws IOException; -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java b/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java deleted file mode 100644 index 8ff7b6549b..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/java/typed.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.expressions.java; - -import org.apache.spark.annotation.Experimental; -import org.apache.spark.api.java.function.MapFunction; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.TypedColumn; -import org.apache.spark.sql.execution.aggregate.TypedAverage; -import org.apache.spark.sql.execution.aggregate.TypedCount; -import org.apache.spark.sql.execution.aggregate.TypedSumDouble; -import org.apache.spark.sql.execution.aggregate.TypedSumLong; - -/** - * :: Experimental :: - * Type-safe functions available for {@link Dataset} operations in Java. - * - * Scala users should use {@link org.apache.spark.sql.expressions.scala.typed}. 
- * - * @since 2.0.0 - */ -@Experimental -public class typed { - // Note: make sure to keep in sync with typed.scala - - /** - * Average aggregate function. - * - * @since 2.0.0 - */ - public static TypedColumn avg(MapFunction f) { - return new TypedAverage(f).toColumnJava(); - } - - /** - * Count aggregate function. - * - * @since 2.0.0 - */ - public static TypedColumn count(MapFunction f) { - return new TypedCount(f).toColumnJava(); - } - - /** - * Sum aggregate function for floating point (double) type. - * - * @since 2.0.0 - */ - public static TypedColumn sum(MapFunction f) { - return new TypedSumDouble(f).toColumnJava(); - } - - /** - * Sum aggregate function for integral (long, i.e. 64 bit integer) type. - * - * @since 2.0.0 - */ - public static TypedColumn sumLong(MapFunction f) { - return new TypedSumLong(f).toColumnJava(); - } -} -- cgit v1.2.3 From 3fb09afd5e55b9a7a0a332273f09f984a78c3645 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 9 Apr 2016 23:32:17 -0700 Subject: [SPARK-14506][SQL] HiveClientImpl's toHiveTable misses a table property for external tables ## What changes were proposed in this pull request? For an external table's metadata (in Hive's representation), its table type needs to be EXTERNAL_TABLE. Also, there needs to be a field called EXTERNAL set in the table property with a value of TRUE (for a MANAGED_TABLE it will be FALSE) based on https://github.com/apache/hive/blob/release-1.2.1/metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105. HiveClientImpl's toHiveTable misses to set this table property. ## How was this patch tested? Added a new test. Author: Yin Huai Closes #12275 from yhuai/SPARK-14506. --- .../spark/sql/catalyst/catalog/CatalogTestCases.scala | 9 +++++++++ .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 13 +++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index fbcac09ce2..0009438b31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -149,6 +149,15 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { // Tables // -------------------------------------------------------------------------- + test("the table type of an external table should be EXTERNAL_TABLE") { + val catalog = newBasicCatalog() + val table = + newTable("external_table1", "db2").copy(tableType = CatalogTableType.EXTERNAL_TABLE) + catalog.createTable("db2", table, ignoreIfExists = false) + val actual = catalog.getTable("db2", "external_table1") + assert(actual.tableType === CatalogTableType.EXTERNAL_TABLE) + } + test("drop table") { val catalog = newBasicCatalog() assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index d0eb9ddf50..a037671ef0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -659,9 +659,18 @@ private[hive] class HiveClientImpl( private def toHiveTable(table: CatalogTable): HiveTable = { val hiveTable = new HiveTable(table.database, 
table.identifier.table) + // For EXTERNAL_TABLE/MANAGED_TABLE, we also need to set EXTERNAL field in + // the table properties accodringly. Otherwise, if EXTERNAL_TABLE is the table type + // but EXTERNAL field is not set, Hive metastore will change the type to + // MANAGED_TABLE (see + // metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105) hiveTable.setTableType(table.tableType match { - case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE - case CatalogTableType.MANAGED_TABLE => HiveTableType.MANAGED_TABLE + case CatalogTableType.EXTERNAL_TABLE => + hiveTable.setProperty("EXTERNAL", "TRUE") + HiveTableType.EXTERNAL_TABLE + case CatalogTableType.MANAGED_TABLE => + hiveTable.setProperty("EXTERNAL", "FALSE") + HiveTableType.MANAGED_TABLE case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW }) -- cgit v1.2.3 From 2c95e4e966b90d2a315350608d4b21b0381dfd11 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Sat, 9 Apr 2016 23:34:14 -0700 Subject: [SPARK-14455][STREAMING] Fix NPE in allocatedExecutors when calling in receiver-less scenario ## What changes were proposed in this pull request? When calling `ReceiverTracker#allocatedExecutors` in receiver-less scenario, NPE will be thrown, since this `ReceiverTracker` actually is not started and `endpoint` is not created. This will be happened when playing streaming dynamic allocation with direct Kafka. ## How was this patch tested? Local integrated test is done. Author: jerryshao Closes #12236 from jerryshao/SPARK-14455. --- .../streaming/scheduler/ReceiverTracker.scala | 12 ++++++++--- .../streaming/scheduler/ReceiverTrackerSuite.scala | 23 +++++++++++++++++++++- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index d67f70732d..3b33a979df 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -240,9 +240,15 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false * Get the executors allocated to each receiver. * @return a map containing receiver ids to optional executor ids. 
*/ - def allocatedExecutors(): Map[Int, Option[String]] = { - endpoint.askWithRetry[Map[Int, ReceiverTrackingInfo]](GetAllReceiverInfo).mapValues { - _.runningExecutor.map { _.executorId } + def allocatedExecutors(): Map[Int, Option[String]] = synchronized { + if (isTrackerStarted) { + endpoint.askWithRetry[Map[Int, ReceiverTrackingInfo]](GetAllReceiverInfo).mapValues { + _.runningExecutor.map { + _.executorId + } + } + } else { + Map.empty } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala index 7654bb2d03..df122ac090 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart, TaskLo import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream.ReceiverInputDStream +import org.apache.spark.streaming.dstream.{ConstantInputDStream, ReceiverInputDStream} import org.apache.spark.streaming.receiver._ /** Testsuite for receiver scheduling */ @@ -102,6 +102,27 @@ class ReceiverTrackerSuite extends TestSuiteBase { } } } + + test("get allocated executors") { + // Test get allocated executors when 1 receiver is registered + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + val input = ssc.receiverStream(new TestReceiver) + val output = new TestOutputStream(input) + output.register() + ssc.start() + assert(ssc.scheduler.receiverTracker.allocatedExecutors().size === 1) + } + + // Test get allocated executors when there's no receiver registered + withStreamingContext(new StreamingContext(conf, Milliseconds(100))) { ssc => + val rdd = ssc.sc.parallelize(1 to 10) + val input = new ConstantInputDStream(ssc, rdd) + val output = new TestOutputStream(input) + output.register() + ssc.start() + assert(ssc.scheduler.receiverTracker.allocatedExecutors() === Map.empty) + } + } } /** An input DStream with for testing rate controlling */ -- cgit v1.2.3 From 22014e6fb919a35c31d852b7c2f5b7eb05751208 Mon Sep 17 00:00:00 2001 From: Jason Moore Date: Sat, 9 Apr 2016 23:34:57 -0700 Subject: [SPARK-14357][CORE] Properly handle the root cause being a commit denied exception ## What changes were proposed in this pull request? When deciding whether a CommitDeniedException caused a task to fail, consider the root cause of the Exception. ## How was this patch tested? Added a test suite for the component that extracts the root cause of the error. Made a distribution after cherry-picking this commit to branch-1.6 and used to run our Spark application that would quite often fail due to the CommitDeniedException. Author: Jason Moore Closes #12228 from jasonmoore2k/SPARK-14357. 
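The fix below hinges on an extractor object whose `unapply` walks `getCause` until it reaches the root of the exception chain. A compact sketch of that pattern, with illustrative names rather than Spark's exact code:

```scala
// Root-cause extractor: unapply recurses through getCause and falls back to
// the throwable itself when there is no deeper cause.
object RootCause {
  def unapply(e: Throwable): Option[Throwable] =
    Option(e.getCause).flatMap(unapply).orElse(Some(e))
}

// Usage: match on the innermost cause rather than the outermost wrapper.
val wrapped = new RuntimeException(new IllegalStateException("commit denied"))
wrapped match {
  case RootCause(root: IllegalStateException) => println(s"root cause: ${root.getMessage}")
  case RootCause(other)                       => println(s"unexpected root cause: $other")
}
```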
--- .../scala/org/apache/spark/executor/Executor.scala | 2 +- .../scala/org/apache/spark/util/CausedBy.scala | 36 ++++++++++++++ .../org/apache/spark/util/CausedBySuite.scala | 56 ++++++++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 core/src/main/scala/org/apache/spark/util/CausedBy.scala create mode 100644 core/src/test/scala/org/apache/spark/util/CausedBySuite.scala diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 09c5733565..afa4d6093a 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -321,7 +321,7 @@ private[spark] class Executor( logInfo(s"Executor killed $taskName (TID $taskId)") execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled)) - case cDE: CommitDeniedException => + case CausedBy(cDE: CommitDeniedException) => val reason = cDE.toTaskEndReason execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) diff --git a/core/src/main/scala/org/apache/spark/util/CausedBy.scala b/core/src/main/scala/org/apache/spark/util/CausedBy.scala new file mode 100644 index 0000000000..73df446d98 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CausedBy.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +/** + * Extractor Object for pulling out the root cause of an error. + * If the error contains no cause, it will return the error itself. + * + * Usage: + * try { + * ... + * } catch { + * case CausedBy(ex: CommitDeniedException) => ... + * } + */ +private[spark] object CausedBy { + + def unapply(e: Throwable): Option[Throwable] = { + Option(e.getCause).flatMap(cause => unapply(cause)).orElse(Some(e)) + } +} diff --git a/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala b/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala new file mode 100644 index 0000000000..4a80e3f1f4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/CausedBySuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.SparkFunSuite + +class CausedBySuite extends SparkFunSuite { + + test("For an error without a cause, should return the error") { + val error = new Exception + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === error) + } + + test("For an error with a cause, should return the cause of the error") { + val cause = new Exception + val error = new Exception(cause) + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === cause) + } + + test("For an error with a cause that itself has a cause, return the root cause") { + val causeOfCause = new Exception + val cause = new Exception(causeOfCause) + val error = new Exception(cause) + + val causedBy = error match { + case CausedBy(e) => e + } + + assert(causedBy === causeOfCause) + } +} -- cgit v1.2.3 From f4344582ba28983bf3892d08e11236f090f5bf92 Mon Sep 17 00:00:00 2001 From: fwang1 Date: Sun, 10 Apr 2016 01:13:25 -0700 Subject: [SPARK-14497][ML] Use top instead of sortBy() to get top N frequent words as dict in ConutVectorizer ## What changes were proposed in this pull request? Replace sortBy() with top() to calculate the top N frequent words as dictionary. ## How was this patch tested? existing unit tests. The terms with same TF would be sorted in descending order. The test would fail if hardcode the terms with same TF the dictionary like "c", "d"... Author: fwang1 Closes #12265 from lionelfeng/master. --- .../org/apache/spark/ml/feature/CountVectorizer.scala | 14 ++++---------- .../org/apache/spark/ml/feature/CountVectorizerSuite.scala | 7 ++++--- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index f1be971a6a..00abbbe29c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String) (word, count) }.cache() val fullVocabSize = wordCounts.count() - val vocab: Array[String] = { - val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) { - // Use all terms - wordCounts.collect().sortBy(-_._2) - } else { - // Sort terms to select vocab - wordCounts.sortBy(_._2, ascending = false).take(vocSize) - } - tmpSortedWC.map(_._1) - } + + val vocab = wordCounts + .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2)) + .map(_._1) require(vocab.length > 0, "The vocabulary size should be > 0. 
Lower minDF as necessary.") copyValues(new CountVectorizerModel(uid, vocab).setParent(this)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala index ff0de06e27..7641e3b8cf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala @@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext (0, split("a b c d e"), Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))), (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))), - (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))), - (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) + (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))), + (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))), + (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0))))) ).toDF("id", "words", "expected") val cv = new CountVectorizer() .setInputCol("words") .setOutputCol("features") .fit(df) - assert(cv.vocabulary === Array("a", "b", "c", "d", "e")) + assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e")) cv.transform(df).select("features", "expected").collect().foreach { case Row(features: Vector, expected: Vector) => -- cgit v1.2.3 From b5c785629acb9afa5a62de3da472ec2184a31e3d Mon Sep 17 00:00:00 2001 From: Örjan Lundberg Date: Sun, 10 Apr 2016 16:30:30 +0100 Subject: Update KMeansExample.scala MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? example does not work wo DataFrame import ## How was this patch tested? example doc only example does not work wo DataFrame import Author: Örjan Lundberg Closes #12277 from oluies/patch-1. --- .../src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index af90652b55..7af011571f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -23,8 +23,8 @@ import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.clustering.KMeans import org.apache.spark.mllib.linalg.Vectors -// $example off$ import org.apache.spark.sql.{DataFrame, SQLContext} +// $example off$ /** * An example demonstrating a k-means clustering. -- cgit v1.2.3 From a7ce473bd0520c71154ed028f295dab64a7485fe Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 10 Apr 2016 11:46:45 -0700 Subject: [SPARK-14415][SQL] All functions should show usages by command `DESC FUNCTION` ## What changes were proposed in this pull request? Currently, many functions do now show usages like the followings. ``` scala> sql("desc function extended `sin`").collect().foreach(println) [Function: sin] [Class: org.apache.spark.sql.catalyst.expressions.Sin] [Usage: To be added.] [Extended Usage: To be added.] ``` This PR adds descriptions for functions and adds a testcase prevent adding function without usage. ``` scala> sql("desc function extended `sin`").collect().foreach(println); [Function: sin] [Class: org.apache.spark.sql.catalyst.expressions.Sin] [Usage: sin(x) - Returns the sine of x.] 
[Extended Usage: > SELECT sin(0); 0.0] ``` The only exceptions are `cube`, `grouping`, `grouping_id`, `rollup`, `window`. ## How was this patch tested? Pass the Jenkins tests (including new testcases.) Author: Dongjoon Hyun Closes #12185 from dongjoon-hyun/SPARK-14415. --- .../catalyst/expressions/aggregate/Average.scala | 2 + .../expressions/aggregate/CentralMomentAgg.scala | 14 +++ .../sql/catalyst/expressions/aggregate/Corr.scala | 2 + .../sql/catalyst/expressions/aggregate/Count.scala | 6 + .../expressions/aggregate/Covariance.scala | 4 + .../sql/catalyst/expressions/aggregate/First.scala | 5 + .../aggregate/HyperLogLogPlusPlus.scala | 7 +- .../sql/catalyst/expressions/aggregate/Last.scala | 2 + .../sql/catalyst/expressions/aggregate/Max.scala | 2 + .../sql/catalyst/expressions/aggregate/Min.scala | 3 +- .../sql/catalyst/expressions/aggregate/Sum.scala | 2 + .../sql/catalyst/expressions/arithmetic.scala | 23 +++- .../catalyst/expressions/bitwiseExpressions.scala | 12 ++ .../expressions/collectionOperations.scala | 10 ++ .../catalyst/expressions/complexTypeCreator.scala | 10 ++ .../expressions/conditionalExpressions.scala | 13 ++- .../catalyst/expressions/datetimeExpressions.scala | 82 ++++++++++++- .../sql/catalyst/expressions/generators.scala | 4 + .../sql/catalyst/expressions/jsonExpressions.scala | 6 + .../sql/catalyst/expressions/mathExpressions.scala | 129 ++++++++++++++++++++- .../spark/sql/catalyst/expressions/misc.scala | 2 + .../sql/catalyst/expressions/nullExpressions.scala | 11 ++ .../sql/catalyst/expressions/predicates.scala | 29 +++-- .../catalyst/expressions/randomExpressions.scala | 4 + .../catalyst/expressions/regexpExpressions.scala | 14 ++- .../catalyst/expressions/stringExpressions.scala | 106 ++++++++++++++++- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++ .../spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 28 files changed, 489 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 94ac4bf09b..ff70774847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the mean calculated from values of a group.") case class Average(child: Expression) extends DeclarativeAggregate { override def prettyName: String = "avg" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 9d2db45144..17a7c6dce8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -130,6 +130,10 @@ abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate } // Compute the population standard deviation of a column +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the population standard deviation calculated from values of a group.") +// scalastyle:on 
line.size.limit case class StddevPop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -143,6 +147,8 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { } // Compute the sample standard deviation of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sample standard deviation calculated from values of a group.") case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -157,6 +163,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { } // Compute the population variance of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the population variance calculated from values of a group.") case class VariancePop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -170,6 +178,8 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { } // Compute the sample variance of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sample variance calculated from values of a group.") case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -183,6 +193,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override def prettyName: String = "var_samp" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the Skewness value calculated from values of a group.") case class Skewness(child: Expression) extends CentralMomentAgg(child) { override def prettyName: String = "skewness" @@ -196,6 +208,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the Kurtosis value calculated from values of a group.") case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 4 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index e6b8214ef2..e29265e2f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns Pearson coefficient of correlation between a set of number pairs.") case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = Seq(x, y) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 663c69e799..17ae012af7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(*) - Returns the total number of retrieved rows, including 
rows containing NULL values. + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-NULL. + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-NULL.""") +// scalastyle:on line.size.limit case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index c175a8c4c7..d80afbebf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -76,6 +76,8 @@ abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggre } } +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the population covariance of a set of number pairs.") case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), @@ -85,6 +87,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance } +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the sample covariance of a set of number pairs.") case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 35f57426fe..b8ab0364dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -28,6 +28,11 @@ import org.apache.spark.sql.types._ * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). */ +@ExpressionDescription( + usage = """_FUNC_(expr) - Returns the first value of `child` for a group of rows. + _FUNC_(expr,isIgnoreNull=false) - Returns the first value of `child` for a group of rows. + If isIgnoreNull is true, returns only non-null values. 
+ """) case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index b6bd56cff6..1d218da6db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.lang.{Long => JLong} import java.util -import com.clearspring.analytics.hash.MurmurHash - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -48,6 +46,11 @@ import org.apache.spark.sql.types._ * @param relativeSD the maximum estimation error allowed. */ // scalastyle:on +@ExpressionDescription( + usage = """_FUNC_(expr) - Returns the estimated cardinality by HyperLogLog++. + _FUNC_(expr, relativeSD=0.05) - Returns the estimated cardinality by HyperLogLog++ + with relativeSD, the maximum estimation error allowed. + """) case class HyperLogLogPlusPlus( child: Expression, relativeSD: Double = 0.05, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index be7e12d7a2..b05d74b49b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). 
*/ +@ExpressionDescription( + usage = "_FUNC_(expr,isIgnoreNull) - Returns the last value of `child` for a group of rows.") case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 906003188d..c534fe495f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the maximum value of expr.") case class Max(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 39f7afbd08..35289b4681 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the minimum value of expr.") case class Min(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 08a67ea3df..ad217f25b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sum calculated from values of a group.") case class Sum(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b388091538..f3d42fc0b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval - +@ExpressionDescription( + usage = "_FUNC_(a) - Returns -a.") case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { @@ -59,6 +60,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression override def sql: String = s"(-${child.sql})" } 
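// Illustrative aside (not part of this patch): exercising the documented aggregates from a
// spark-shell, assuming `sqlContext` and its implicits are available.
import sqlContext.implicits._
Seq(1, 2, 3).toDF("v").registerTempTable("t")
sqlContext.sql("SELECT count(v), max(v), min(v), sum(v) FROM t").show()
// expected: count = 3, max = 3, min = 1, sum = 6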
+@ExpressionDescription( + usage = "_FUNC_(a) - Returns a.") case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def prettyName: String = "positive" @@ -79,8 +82,8 @@ case class UnaryPositive(child: Expression) * A function that get the absolute value of the numeric value. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", - extended = "> SELECT _FUNC_('-1');\n1") + usage = "_FUNC_(expr) - Returns the absolute value of the numeric value.", + extended = "> SELECT _FUNC_('-1');\n 1") case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { @@ -126,6 +129,8 @@ private[sql] object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a+b.") case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -155,6 +160,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit } } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a-b.") case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -185,6 +192,8 @@ case class Subtract(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "a _FUNC_ b - Multiplies a by b.") case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -198,6 +207,9 @@ case class Multiply(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } +@ExpressionDescription( + usage = "a _FUNC_ b - Divides a by b.", + extended = "> SELECT 3 _FUNC_ 2;\n 1.5") case class Divide(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -275,6 +287,8 @@ case class Divide(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns the remainder when dividing a by b.") case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -464,6 +478,9 @@ case class MinOf(left: Expression, right: Expression) override def symbol: String = "min" } +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns the positive modulo", + extended = "> SELECT _FUNC_(10,3);\n 1") case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 4c90b3f7d3..a7e1cd66f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -26,6 +26,9 @@ import org.apache.spark.sql.types._ * * Code generation inherited from BinaryArithmetic. 
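// Illustrative aside (not part of this patch): the "positive modulo" documented for Pmod above
// differs from the JVM remainder for negative operands; `pmod` below is a hypothetical
// plain-Scala helper sketching the contract, not Spark's implementation.
def pmod(a: Int, n: Int): Int = ((a % n) + n) % n
println(-7 % 3)       // -1, JVM remainder (comparable to the Remainder expression above)
println(pmod(-7, 3))  //  2, always non-negative for a positive modulus
println(pmod(10, 3))  //  1, matching the extended example above
println(3.0 / 2)      //  1.5, matching the Divide extended example above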
*/ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise AND.", + extended = "> SELECT 3 _FUNC_ 5; 1") case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -51,6 +54,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise OR.", + extended = "> SELECT 3 _FUNC_ 5; 7") case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -76,6 +82,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise exclusive OR.", + extended = "> SELECT 3 _FUNC_ 5; 2") case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -99,6 +108,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise not(~) of a number. */ +@ExpressionDescription( + usage = "_FUNC_ b - Bitwise NOT.", + extended = "> SELECT _FUNC_ 0; -1") case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e36c985249..ab790cf372 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.types._ /** * Given an array or map, returns its size. */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the size of an array or a map.") case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) @@ -44,6 +46,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. 
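// Illustrative aside (not part of this patch): the bitwise extended examples above, checked
// with the corresponding JVM operators.
println(3 & 5)  // 1
println(3 | 5)  // 7
println(3 ^ 5)  // 6 (exclusive OR of 0b011 and 0b101)
println(~0)     // -1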
*/ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.", + extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'));\n 'a', 'b', 'c', 'd'") +// scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { @@ -125,6 +132,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) /** * Checks if the array (left) has the element (right) */ +@ExpressionDescription( + usage = "_FUNC_(array, value) - Returns TRUE if the array contains value.", + extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true") case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index c299586dde..74de4a776d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -27,6 +27,8 @@ import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. */ +@ExpressionDescription( + usage = "_FUNC_(n0, ...) - Returns an array with the given elements.") case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -73,6 +75,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * Returns a catalyst Map containing the evaluation of all children expressions as keys and values. * The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...) */ +@ExpressionDescription( + usage = "_FUNC_(key0, value0, key1, value1...) - Creates a map with the given key/value pairs.") case class CreateMap(children: Seq[Expression]) extends Expression { private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children) private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children) @@ -153,6 +157,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { /** * Returns a Row containing the evaluation of all children expressions. */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -204,6 +210,10 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit case class CreateNamedStruct(children: Seq[Expression]) extends Expression { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 35a7b46020..ae6a94842f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -23,7 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1,expr2,expr3) - If expr1 is TRUE then IF() returns expr2; otherwise it returns expr3.") +// scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) extends Expression { @@ -85,6 +88,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi * @param branches seq of (branch condition, branch value) * @param elseValue optional value for the else branch */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") +// scalastyle:on line.size.limit case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) extends Expression with CodegenFallback { @@ -256,6 +263,8 @@ object CaseKeyWhen { * A function that returns the least value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. */ +@ExpressionDescription( + usage = "_FUNC_(n1, ...) - Returns the least value of all parameters, skipping null values.") case class Least(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) @@ -315,6 +324,8 @@ case class Least(children: Seq[Expression]) extends Expression { * A function that returns the greatest value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. */ +@ExpressionDescription( + usage = "_FUNC_(n1, ...) - Returns the greatest value of all parameters, skipping null values.") case class Greatest(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1d0ea68d7a..9135753041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -35,6 +35,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * * There is no code generation since this expression should get constant folded by the optimizer. 
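// Illustrative aside (not part of this patch): plain-Scala analogues of the collection and
// conditional expressions documented above (Scala collections stand in for SQL arrays,
// Option for SQL NULL).
println(Seq("b", "d", "c", "a").sorted)               // List(a, b, c, d), like sort_array
println(Seq(1, 2, 3).contains(2))                     // true, like array_contains
println(Seq(Option(3), None, Option(1)).flatten.min)  // 1, the least-skipping-nulls idea
println(Seq(Option(3), None, Option(1)).flatten.max)  // 3, the greatest-skipping-nulls idea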
*/ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current date at the start of query evaluation.") case class CurrentDate() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -54,6 +56,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { * * There is no code generation since this expression should get constant folded by the optimizer. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current timestamp at the start of query evaluation.") case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -70,6 +74,9 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { /** * Adds a number of days to startdate. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days after start_date.", + extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-31'") case class DateAdd(startDate: Expression, days: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -96,6 +103,9 @@ case class DateAdd(startDate: Expression, days: Expression) /** * Subtracts a number of days to startdate. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days before start_date.", + extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'") case class DateSub(startDate: Expression, days: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = startDate @@ -118,6 +128,9 @@ case class DateSub(startDate: Expression, days: Expression) override def prettyName: String = "date_sub" } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the hour component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 12") case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -134,6 +147,9 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the minute component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 58") case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -150,6 +166,9 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the second component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 59") case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -166,6 +185,9 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the day of year of date/timestamp.", + extended = "> SELECT _FUNC_('2016-04-09');\n 100") case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -182,7 +204,9 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas } } - +@ExpressionDescription( + usage = "_FUNC_(param) - Returns 
the year component of the date/timestamp/interval.", + extended = "> SELECT _FUNC_('2016-07-30');\n 2016") case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -199,6 +223,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the quarter of the year for date, in the range 1 to 4.") case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -215,6 +241,9 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the month component of the date/timestamp/interval", + extended = "> SELECT _FUNC_('2016-07-30');\n 7") case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -231,6 +260,9 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the day of month of date/timestamp, or the day of interval.", + extended = "> SELECT _FUNC_('2009-07-30');\n 30") case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -247,6 +279,9 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the week of the year of the given date.", + extended = "> SELECT _FUNC_('2008-02-20');\n 8") case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -283,6 +318,11 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date/timestamp/string, fmt) - Converts a date/timestamp/string to a value of string in the format specified by the date format fmt.", + extended = "> SELECT _FUNC_('2016-04-08', 'y')\n '2016'") +// scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -310,6 +350,8 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx * Converts time string with given pattern. * Deterministic version of [[UnixTimestamp]], must have at least one parameter. */ +@ExpressionDescription( + usage = "_FUNC_(date[, pattern]) - Returns the UNIX timestamp of the give time.") case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -331,6 +373,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix * If the first parameter is a Date or Timestamp instead of String, we will ignore the * second parameter. */ +@ExpressionDescription( + usage = "_FUNC_([date[, pattern]]) - Returns the UNIX timestamp of current or specified time.") case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -459,6 +503,9 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { * format. 
If the format is missing, using format like "1970-01-01 00:00:00". * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. */ +@ExpressionDescription( + usage = "_FUNC_(unix_time, format) - Returns unix_time in the specified format", + extended = "> SELECT _FUNC_(0, 'yyyy-MM-dd HH:mm:ss');\n '1970-01-01 00:00:00'") case class FromUnixTime(sec: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -544,6 +591,9 @@ case class FromUnixTime(sec: Expression, format: Expression) /** * Returns the last day of the month which the date belongs to. */ +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the last day of the month which the date belongs to.", + extended = "> SELECT _FUNC_('2009-01-12');\n '2009-01-31'") case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def child: Expression = startDate @@ -570,6 +620,11 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC * * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(start_date, day_of_week) - Returns the first date which is later than start_date and named as indicated.", + extended = "> SELECT _FUNC_('2015-01-14', 'TU');\n '2015-01-20'") +// scalastyle:on line.size.limit case class NextDay(startDate: Expression, dayOfWeek: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -654,6 +709,10 @@ case class TimeAdd(start: Expression, interval: Expression) /** * Assumes given timestamp is UTC and converts to given timezone. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is UTC and converts to given timezone.") +// scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -729,6 +788,9 @@ case class TimeSub(start: Expression, interval: Expression) /** * Returns the date that is num_months after start_date. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_months) - Returns the date that is num_months after start_date.", + extended = "> SELECT _FUNC_('2016-08-31', 1);\n '2016-09-30'") case class AddMonths(startDate: Expression, numMonths: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -756,6 +818,9 @@ case class AddMonths(startDate: Expression, numMonths: Expression) /** * Returns number of months between dates date1 and date2. */ +@ExpressionDescription( + usage = "_FUNC_(date1, date2) - returns number of months between dates date1 and date2.", + extended = "> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');\n 3.94959677") case class MonthsBetween(date1: Expression, date2: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -783,6 +848,10 @@ case class MonthsBetween(date1: Expression, date2: Expression) /** * Assumes given timestamp is in given timezone and converts to UTC. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is in given timezone and converts to UTC.") +// scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -830,6 +899,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) /** * Returns the date part of a timestamp or string. 
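// Illustrative aside (not part of this patch): java.time checks that mirror some of the date
// examples documented above (Spark's own logic lives in DateTimeUtils).
import java.time.LocalDate
println(LocalDate.parse("2016-07-30").plusDays(1))                                   // 2016-07-31, like date_add
println(LocalDate.parse("2016-04-09").getDayOfYear)                                  // 100, like dayofyear
println(LocalDate.parse("2009-01-12").plusMonths(1).withDayOfMonth(1).minusDays(1))  // 2009-01-31, like last_day
println(LocalDate.parse("2016-08-31").plusMonths(1))                                 // 2016-09-30, like add_months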
*/ +@ExpressionDescription( + usage = "_FUNC_(expr) - Extracts the date part of the date or datetime expression expr.", + extended = "> SELECT _FUNC_('2009-07-30 04:17:52');\n '2009-07-30'") case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // Implicit casting of spark will accept string in both date and timestamp format, as @@ -850,6 +922,11 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn /** * Returns date truncated to the unit specified by the format. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date, fmt) - Returns returns date with the time portion of the day truncated to the unit specified by the format model fmt.", + extended = "> SELECT _FUNC_('2009-02-12', 'MM')\n '2009-02-01'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01'") +// scalastyle:on line.size.limit case class TruncDate(date: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = date @@ -921,6 +998,9 @@ case class TruncDate(date: Expression, format: Expression) /** * Returns the number of days from startDate to endDate. */ +@ExpressionDescription( + usage = "_FUNC_(date1, date2) - Returns the number of days between date1 and date2.", + extended = "> SELECT _FUNC_('2009-07-30', '2009-07-31');\n 1") case class DateDiff(endDate: Expression, startDate: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e7ef21aa85..65d7a1d5a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -99,6 +99,10 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.") +// scalastyle:on line.size.limit case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 72b323587c..ecd09b7083 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -106,6 +106,8 @@ private[this] object SharedFactory { * Extracts json object from a json string based on json path specified, and returns json string * of the extracted json object. It will return null if the input json string is invalid. 
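// Illustrative aside (not part of this patch): counting whole days between two dates with
// java.time, for comparison with the datediff usage documented above.
import java.time.LocalDate
import java.time.temporal.ChronoUnit
println(ChronoUnit.DAYS.between(LocalDate.parse("2009-07-30"), LocalDate.parse("2009-07-31")))  // 1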
*/ +@ExpressionDescription( + usage = "_FUNC_(json_txt, path) - Extract a json object from path") case class GetJsonObject(json: Expression, path: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { @@ -319,6 +321,10 @@ case class GetJsonObject(json: Expression, path: Expression) } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - like get_json_object, but it takes multiple names and return a tuple. All the input parameters and output column types are string.") +// scalastyle:on line.size.limit case class JsonTuple(children: Seq[Expression]) extends Generator with CodegenFallback { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e3d1bc127d..c8a28e8477 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -50,6 +50,7 @@ abstract class LeafMathExpression(c: Double, name: String) /** * A unary expression specifically for math functions. Math Functions expect a specific type of * input format, therefore these functions extend `ExpectsInputTypes`. + * * @param f The math function. * @param name The short name of the function */ @@ -103,6 +104,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) /** * A binary expression specifically for math functions that take two `Double`s as input and returns * a `Double`. + * * @param f The math function. * @param name The short name of the function */ @@ -136,12 +138,18 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) * Euler's number. Note that there is no code generation because this is only * evaluated by the optimizer during constant folding. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns Euler's number, E.", + extended = "> SELECT _FUNC_();\n 2.718281828459045") case class EulerNumber() extends LeafMathExpression(math.E, "E") /** * Pi. Note that there is no code generation because this is only * evaluated by the optimizer during constant folding. 
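// Illustrative aside (not part of this patch): the e() / pi() extended examples above and the
// out-of-domain behaviour documented for the inverse trig functions that follow, checked with
// plain scala.math.
println(math.E)          // 2.718281828459045
println(math.Pi)         // 3.141592653589793
println(math.acos(2.0))  // NaN, since |x| > 1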
*/ +@ExpressionDescription( + usage = "_FUNC_() - Returns PI.", + extended = "> SELECT _FUNC_();\n 3.141592653589793") case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -150,14 +158,29 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc cosine of x if -1<=x<=1 or NaN otherwise.", + extended = "> SELECT _FUNC_(1);\n 0.0\n> SELECT _FUNC_(2);\n NaN") case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc sin of x if -1<=x<=1 or NaN otherwise.", + extended = "> SELECT _FUNC_(0);\n 0.0\n> SELECT _FUNC_(2);\n NaN") case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc tangent.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the cube root of a double value.", + extended = "> SELECT _FUNC_(27.0);\n 3.0") case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the smallest integer not smaller than x.", + extended = "> SELECT _FUNC_(-0.1);\n 0\n> SELECT _FUNC_(5);\n 5") case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -184,16 +207,26 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the cosine of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic cosine of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") /** * Convert a num from one base to another + * * @param numExpr the number to be converted * @param fromBaseExpr from which base * @param toBaseExpr to which base */ +@ExpressionDescription( + usage = "_FUNC_(num, from_base, to_base) - Convert num from from_base to to_base.", + extended = "> SELECT _FUNC_('100', 2, 10);\n '4'\n> SELECT _FUNC_(-10, 16, -10);\n '16'") case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -222,10 +255,19 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns e to the power of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns exp(x) - 1.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the largest integer not greater than x.", + extended = "> SELECT _FUNC_(-0.1);\n -1\n> SELECT _FUNC_(5);\n 5") case class 
Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -283,6 +325,9 @@ object Factorial { ) } +@ExpressionDescription( + usage = "_FUNC_(n) - Returns n factorial for n is [0..20]. Otherwise, NULL.", + extended = "> SELECT _FUNC_(5);\n 120") case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -315,8 +360,14 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the natural logarithm of x with base e.", + extended = "> SELECT _FUNC_(1);\n 0.0") case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the logarithm of x with base 2.", + extended = "> SELECT _FUNC_(2);\n 1.0") case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodegenContext, ev: ExprCode): String = { @@ -332,36 +383,72 @@ case class Log2(child: Expression) } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the logarithm of x with base 10.", + extended = "> SELECT _FUNC_(10);\n 1.0") case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns log(1 + x).", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") { protected override val yAsymptote: Double = -1.0 } +@ExpressionDescription( + usage = "_FUNC_(x, d) - Return the rounded x at d decimal places.", + extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sign of x.", + extended = "> SELECT _FUNC_(40);\n 1.0") case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sine of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic sine of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the square root of x.", + extended = "> SELECT _FUNC_(4);\n 2.0") case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the tangent of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic tangent of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") +@ExpressionDescription( + usage = "_FUNC_(x) - Converts radians to degrees.", + extended = "> SELECT _FUNC_(3.141592653589793);\n 180.0") case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { override def funcName: String = "toDegrees" } +@ExpressionDescription( + usage = "_FUNC_(x) - Converts degrees to radians.", + extended = "> 
SELECT _FUNC_(180);\n 3.141592653589793") case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { override def funcName: String = "toRadians" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns x in binary.", + extended = "> SELECT _FUNC_(13);\n '1101'") case class Bin(child: Expression) extends UnaryExpression with Serializable with ImplicitCastInputTypes { @@ -453,6 +540,9 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ +@ExpressionDescription( + usage = "_FUNC_(x) - Convert the argument to hexadecimal.", + extended = "> SELECT _FUNC_(17);\n '11'\n> SELECT _FUNC_('Spark SQL');\n '537061726B2053514C'") case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = @@ -481,6 +571,9 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ +@ExpressionDescription( + usage = "_FUNC_(x) - Converts hexadecimal argument to binary.", + extended = "> SELECT decode(_FUNC_('537061726B2053514C'),'UTF-8');\n 'Spark SQL'") case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -509,7 +602,9 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// - +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the arc tangent2.", + extended = "> SELECT _FUNC_(0, 0);\n 0.0") case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { @@ -523,6 +618,9 @@ case class Atan2(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "_FUNC_(x1, x2) - Raise x1 to the power of x2.", + extended = "> SELECT _FUNC_(2, 3);\n 8.0") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodegenContext, ev: ExprCode): String = { @@ -532,10 +630,14 @@ case class Pow(left: Expression, right: Expression) /** - * Bitwise unsigned left shift. + * Bitwise left shift. + * * @param left the base number to shift. * @param right number of bits to left shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise left shift.", + extended = "> SELECT _FUNC_(2, 1);\n 4") case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -558,10 +660,14 @@ case class ShiftLeft(left: Expression, right: Expression) /** - * Bitwise unsigned left shift. + * Bitwise right shift. + * * @param left the base number to shift. - * @param right number of bits to left shift. + * @param right number of bits to right shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise right shift.", + extended = "> SELECT _FUNC_(4, 1);\n 2") case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -585,9 +691,13 @@ case class ShiftRight(left: Expression, right: Expression) /** * Bitwise unsigned right shift, for integer and long data type. + * * @param left the base number. 
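// Illustrative aside (not part of this patch): JVM equivalents of the cbrt, bin, hex, shift
// and power examples documented above.
println(math.cbrt(27.0))                     // 3.0
println(java.lang.Long.toBinaryString(13L))  // 1101
println(java.lang.Long.toHexString(17L))     // 11
println(2 << 1)                              // 4, left shift
println(4 >> 1)                              // 2, sign-propagating right shift
println(-4 >>> 1)                            // 2147483646, unsigned right shift on Int
println(math.pow(2, 3))                      // 8.0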
* @param right the number of bits to right shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise unsigned right shift.", + extended = "> SELECT _FUNC_(4, 1);\n 2") case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -608,16 +718,22 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns sqrt(a**2 + b**2).", + extended = "> SELECT _FUNC_(3, 4);\n 5.0") case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") /** * Computes the logarithm of a number. + * * @param left the logarithm base, default to e. * @param right the number to compute the logarithm of. */ +@ExpressionDescription( + usage = "_FUNC_(b, x) - Returns the logarithm of x with base b.", + extended = "> SELECT _FUNC_(10, 100);\n 2.0") case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { @@ -674,6 +790,9 @@ case class Logarithm(left: Expression, right: Expression) * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime */ +@ExpressionDescription( + usage = "_FUNC_(x, d) - Round x to d decimal places.", + extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") case class Round(child: Expression, scale: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index eb8dc1423a..4bd918ed01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -438,6 +438,8 @@ abstract class InterpretedHashFunction { * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle * and bucketing have same data distribution. */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns a hash value of the arguments.") case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] { def this(arguments: Seq[Expression]) = this(arguments, 42) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index e22026d584..6a45249943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -34,6 +34,9 @@ import org.apache.spark.sql.types._ * coalesce(null, null, null) => null * }}} */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns the first non-null argument if exists. Otherwise, NULL.", + extended = "> SELECT _FUNC_(NULL, 1, NULL);\n 1") case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ @@ -89,6 +92,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { /** * Evaluates to `true` iff it's NaN. 
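// Illustrative aside (not part of this patch): plain-Scala checks of the hypot and
// log-with-base examples above, plus the first-non-null idea behind coalesce (Option stands
// in for SQL NULL).
println(math.hypot(3, 4))                             // 5.0
println(math.log(100) / math.log(10))                 // 2.0 (up to floating point)
println(Seq(None, Some(1), None).flatten.headOption)  // Some(1)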
*/ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is NaN and false otherwise.") case class IsNaN(child: Expression) extends UnaryExpression with Predicate with ImplicitCastInputTypes { @@ -126,6 +131,8 @@ case class IsNaN(child: Expression) extends UnaryExpression * An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise. * This Expression is useful for mapping NaN values to null. */ +@ExpressionDescription( + usage = "_FUNC_(a,b) - Returns a iff it's not NaN, or b otherwise.") case class NaNvl(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -180,6 +187,8 @@ case class NaNvl(left: Expression, right: Expression) /** * An expression that is evaluated to true if the input is null. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is NULL and false otherwise.") case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false @@ -201,6 +210,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { /** * An expression that is evaluated to true if the input is not null. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is not NULL and false otherwise.") case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4eb33258ac..38f1210a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -88,7 +88,8 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - +@ExpressionDescription( + usage = "_FUNC_ a - Logical not") case class Not(child: Expression) extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { @@ -109,6 +110,8 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. */ +@ExpressionDescription( + usage = "expr _FUNC_(val1, val2, ...) 
- Returns true if expr equals to any valN.") case class In(value: Expression, list: Seq[Expression]) extends Predicate with ImplicitCastInputTypes { @@ -243,6 +246,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } } +@ExpressionDescription( + usage = "a _FUNC_ b - Logical AND.") case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType @@ -306,7 +311,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Logical OR.") case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType @@ -401,7 +407,8 @@ private[sql] object Equality { } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a equals b and false otherwise.") case class EqualTo(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -426,7 +433,9 @@ case class EqualTo(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = """a _FUNC_ b - Returns same result with EQUAL(=) operator for non-null operands, + but returns TRUE if both are NULL, FALSE if one of the them is NULL.""") case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { override def inputType: AbstractDataType = AnyDataType @@ -467,7 +476,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is less than b.") case class LessThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -480,7 +490,8 @@ case class LessThan(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is not greater than b.") case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -493,7 +504,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is greater than b.") case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -506,7 +518,8 @@ case class GreaterThan(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is not smaller than b.") case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6be3cbcae6..1ec092a5be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -55,6 +55,8 @@ abstract class RDG extends LeafExpression with Nondeterministic { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
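// Illustrative aside (not part of this patch): a plain-Scala sketch of the null-safe equality
// contract documented for EqualNullSafe above; `eqNullSafe` is a hypothetical helper and
// Option stands in for SQL NULL.
def eqNullSafe[A](a: Option[A], b: Option[A]): Boolean = (a, b) match {
  case (None, None)          => true    // both NULL  => TRUE
  case (None, _) | (_, None) => false   // one NULL   => FALSE
  case (Some(x), Some(y))    => x == y  // otherwise behaves like =
}
println(eqNullSafe(Some(1), Some(1)))  // true
println(eqNullSafe(None, None))        // true
println(eqNullSafe(Some(1), None))     // false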
*/ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a random column with i.i.d. uniformly distributed values in [0, 1).") case class Rand(seed: Long) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() @@ -78,6 +80,8 @@ case class Rand(seed: Long) extends RDG { } /** Generate a random column with i.i.d. gaussian random distribution. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.") case class Randn(seed: Long) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b68009331b..85a5429263 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -67,6 +67,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes { /** * Simple RegEx pattern matching function */ +@ExpressionDescription( + usage = "str _FUNC_ pattern - Returns true if str matches pattern and false otherwise.") case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { @@ -117,7 +119,8 @@ case class Like(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = "str _FUNC_ regexp - Returns true if str matches regexp and false otherwise.") case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { @@ -169,6 +172,9 @@ case class RLike(left: Expression, right: Expression) /** * Splits str around pat (pattern is a regular expression). */ +@ExpressionDescription( + usage = "_FUNC_(str, regex) - Splits str around occurrences that match regex", + extended = "> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');\n ['one', 'two', 'three']") case class StringSplit(str: Expression, pattern: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -198,6 +204,9 @@ case class StringSplit(str: Expression, pattern: Expression) * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ +@ExpressionDescription( + usage = "_FUNC_(str, regexp, rep) - replace all substrings of str that match regexp with rep.", + extended = "> SELECT _FUNC_('100-200', '(\\d+)', 'num');\n 'num-num'") case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -289,6 +298,9 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
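// Illustrative aside (not part of this patch): JVM-regex checks of the split, regexp_replace
// and regexp_extract examples documented above.
println("oneAtwoBthreeC".split("[ABC]").toList)                         // List(one, two, three)
println("100-200".replaceAll("\\d+", "num"))                            // num-num
println("(\\d+)-(\\d+)".r.findFirstMatchIn("100-200").map(_.group(1)))  // Some(100)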
*/ +@ExpressionDescription( + usage = "_FUNC_(str, regexp[, idx]) - extracts a group that matches regexp.", + extended = "> SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1);\n '100'") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) extends TernaryExpression with ImplicitCastInputTypes { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7e0e7a833b..a17482697d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -35,6 +35,9 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} * An expression that concatenates multiple input strings into a single string. * If any input is null, concat returns null. */ +@ExpressionDescription( + usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN", + extended = "> SELECT _FUNC_('Spark','SQL');\n 'SparkSQL'") case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) @@ -70,6 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas * * Returns null if the separator is null. Otherwise, concat_ws skips all null values. */ +@ExpressionDescription( + usage = + "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by sep.", + extended = "> SELECT _FUNC_(' ', Spark', 'SQL');\n 'Spark SQL'") case class ConcatWs(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { @@ -188,7 +195,7 @@ case class Upper(child: Expression) */ @ExpressionDescription( usage = "_FUNC_(str) - Returns str with all characters changed to lowercase", - extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'") + extended = "> SELECT _FUNC_('SparkSql');\n 'sparksql'") case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -270,6 +277,11 @@ object StringTranslate { * The translate will happen when any character in the string matching with the character * in the `matchingExpr`. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(input, from, to) - Translates the input string by replacing the characters present in the from string with the corresponding characters in the to string""", + extended = "> SELECT _FUNC_('AaBbCc', 'abc', '123');\n 'A1B2C3'") +// scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -325,6 +337,12 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac * delimited list (right). Returns 0, if the string wasn't found or if the given * string (left) contains a comma. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(str, str_array) - Returns the index (1-based) of the given string (left) in the comma-delimited list (right). 
+ Returns 0, if the string wasn't found or if the given string (left) contains a comma.""", + extended = "> SELECT _FUNC_('ab','abc,b,ab,c,def');\n 3") +// scalastyle:on case class FindInSet(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -347,6 +365,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi /** * A function that trim the spaces from both ends for the specified string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the leading and trailing space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL'") case class StringTrim(child: Expression) extends UnaryExpression with String2StringExpression { @@ -362,6 +383,9 @@ case class StringTrim(child: Expression) /** * A function that trim the spaces from left end for given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the leading space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL '") case class StringTrimLeft(child: Expression) extends UnaryExpression with String2StringExpression { @@ -377,6 +401,9 @@ case class StringTrimLeft(child: Expression) /** * A function that trim the spaces from right end for given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the trailing space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n ' SparkSQL'") case class StringTrimRight(child: Expression) extends UnaryExpression with String2StringExpression { @@ -396,6 +423,9 @@ case class StringTrimRight(child: Expression) * * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ +@ExpressionDescription( + usage = "_FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of substr in str.", + extended = "> SELECT _FUNC_('SparkSQL', 'SQL');\n 6") case class StringInstr(str: Expression, substr: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -422,6 +452,15 @@ case class StringInstr(str: Expression, substr: Expression) * returned. If count is negative, every to the right of the final delimiter (counting from the * right) is returned. substring_index performs a case-sensitive match when searching for delim. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(str, delim, count) - Returns the substring from str before count occurrences of the delimiter delim. + If count is positive, everything to the left of the final delimiter (counting from the + left) is returned. If count is negative, everything to the right of the final delimiter + (counting from the right) is returned. Substring_index performs a case-sensitive match + when searching for delim.""", + extended = "> SELECT _FUNC_('www.apache.org', '.', 2);\n 'www.apache'") +// scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -445,6 +484,12 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: * A function that returns the position of the first occurrence of substr * in given string after position pos. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos. 
+ The given pos and return value are 1-based.""", + extended = "> SELECT _FUNC_('bar', 'foobarbar', 5);\n 7") +// scalastyle:on line.size.limit case class StringLocate(substr: Expression, str: Expression, start: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -510,6 +555,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) /** * Returns str, left-padded with pad to a length of len. */ +@ExpressionDescription( + usage = """_FUNC_(str, len, pad) - Returns str, left-padded with pad to a length of len. + If str is longer than len, the return value is shortened to len characters.""", + extended = "> SELECT _FUNC_('hi', 5, '??');\n '???hi'\n" + + "> SELECT _FUNC_('hi', 1, '??');\n 'h'") case class StringLPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -531,6 +581,11 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) /** * Returns str, right-padded with pad to a length of len. */ +@ExpressionDescription( + usage = """_FUNC_(str, len, pad) - Returns str, right-padded with pad to a length of len. + If str is longer than len, the return value is shortened to len characters.""", + extended = "> SELECT _FUNC_('hi', 5, '??');\n 'hi???'\n" + + "> SELECT _FUNC_('hi', 1, '??');\n 'h'") case class StringRPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -552,6 +607,11 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(String format, Obj... args) - Returns a formatted string from printf-style format strings.", + extended = "> SELECT _FUNC_(\"Hello World %d %s\", 100, \"days\");\n 'Hello World 100 days'") +// scalastyle:on line.size.limit case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, "format_string() should take at least 1 argument") @@ -642,6 +702,9 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI /** * Returns the string which repeat the given string value n times. */ +@ExpressionDescription( + usage = "_FUNC_(str, n) - Returns the string which repeat the given string value n times.", + extended = "> SELECT _FUNC_('123', 2);\n '123123'") case class StringRepeat(str: Expression, times: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -664,6 +727,9 @@ case class StringRepeat(str: Expression, times: Expression) /** * Returns the reversed given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns the reversed given string.", + extended = "> SELECT _FUNC_('Spark SQL');\n 'LQS krapS'") case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.reverse() @@ -677,6 +743,9 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. */ +@ExpressionDescription( + usage = "_FUNC_(n) - Returns a n spaces string.", + extended = "> SELECT _FUNC_(2);\n ' '") case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -699,7 +768,14 @@ case class StringSpace(child: Expression) /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. 
+ * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str, pos[, len]) - Returns the substring of str that starts at pos and is of length len or the slice of byte array that starts at pos and is of length len.", + extended = "> SELECT _FUNC_('Spark SQL', 5);\n 'k SQL'\n> SELECT _FUNC_('Spark SQL', -3);\n 'SQL'\n> SELECT _FUNC_('Spark SQL', 5, 1);\n 'k'") +// scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -737,6 +813,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string or binary expression. */ +@ExpressionDescription( + usage = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data.", + extended = "> SELECT _FUNC_('Spark SQL');\n 9") case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -757,6 +836,9 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy /** * A function that return the Levenshtein distance between the two given strings. */ +@ExpressionDescription( + usage = "_FUNC_(str1, str2) - Returns the Levenshtein distance between the two given strings.", + extended = "> SELECT _FUNC_('kitten', 'sitting');\n 3") case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -775,6 +857,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * A function that return soundex code of the given string expression. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns soundex code of the string.", + extended = "> SELECT _FUNC_('Miller');\n 'M460'") case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = StringType @@ -791,6 +876,10 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT /** * Returns the numeric value of the first character of str. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns the numeric value of the first character of str.", + extended = "> SELECT _FUNC_('222');\n 50\n" + + "> SELECT _FUNC_(2);\n 50") case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType @@ -822,6 +911,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp /** * Converts the argument from binary to a base 64 string. */ +@ExpressionDescription( + usage = "_FUNC_(bin) - Convert the argument from binary to a base 64 string.") case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -844,6 +935,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn /** * Converts the argument from a base 64 string to BINARY. 
*/ +@ExpressionDescription( + usage = "_FUNC_(str) - Convert the argument from a base 64 string to binary.") case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BinaryType @@ -865,6 +958,8 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. */ +@ExpressionDescription( + usage = "_FUNC_(bin, str) - Decode the first argument using the second argument character set.") case class Decode(bin: Expression, charset: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -894,7 +989,9 @@ case class Decode(bin: Expression, charset: Expression) * Encodes the first argument into a BINARY using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. -*/ + */ +@ExpressionDescription( + usage = "_FUNC_(str, str) - Encode the first argument using the second argument character set.") case class Encode(value: Expression, charset: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -924,6 +1021,11 @@ case class Encode(value: Expression, charset: Expression) * and returns the result as a string. If D is 0, the result has no decimal point or * fractional part. */ +@ExpressionDescription( + usage = """_FUNC_(X, D) - Formats the number X like '#,###,###.##', rounded to D decimal places. + If D is 0, the result has no decimal point or fractional part. + This is supposed to function like MySQL's FORMAT.""", + extended = "> SELECT _FUNC_(12332.123456, 4);\n '12,332.1235'") case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dd648cdb81..695dda269a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -89,6 +89,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "Function: abcadf not found.") } + test("SPARK-14415: All functions should have own descriptions") { + for (f <- sqlContext.sessionState.functionRegistry.listFunction()) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { + checkExistence(sql(s"describe function `$f`"), false, "To be added.") + } + } + } + test("SPARK-6743: no columns from cache") { Seq( (83, 0, 38), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f3796a9966..b4886eba7a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -238,7 +238,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkExistence(sql("describe functioN `~`"), true, "Function: ~", "Class: org.apache.spark.sql.catalyst.expressions.BitwiseNot", - "Usage: To be added.") + "Usage: ~ b - Bitwise NOT.") // Hard coded describe functions checkExistence(sql("describe function `<>`"), true, -- cgit v1.2.3 From fbf8d008833c985d0e222dd2360c7f7375caa68a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 10 Apr 2016 
18:10:44 -0700 Subject: [SPARK-14419] [MINOR] coding style cleanup ## What changes were proposed in this pull request? Making them more consistent. ## How was this patch tested? Existing tests. Author: Davies Liu Closes #12289 from davies/cleanup_style. --- .../execution/aggregate/TungstenAggregate.scala | 2 +- .../spark/sql/execution/joins/HashedRelation.scala | 35 ++++++++-------------- 2 files changed, 13 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 692fef703f..253592028c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -454,7 +454,7 @@ case class TungstenAggregate( val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"") + ctx.addMutableState(hashMapClassName, hashMapTerm, "") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 68b5486faa..0427db4e3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -122,13 +122,12 @@ private[joins] class UnsafeHashedRelation( override def keyIsUnique: Boolean = binaryMap.numKeys() == binaryMap.numValues() - override def asReadOnlyCopy(): UnsafeHashedRelation = + override def asReadOnlyCopy(): UnsafeHashedRelation = { new UnsafeHashedRelation(numFields, binaryMap) - - override def estimatedSize: Long = { - binaryMap.getTotalMemoryConsumption } + override def estimatedSize: Long = binaryMap.getTotalMemoryConsumption + // re-used in get()/getValue() var resultRow = new UnsafeRow(numFields) @@ -374,8 +373,9 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap // do not support spilling val got = mm.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this) if (got < size) { - mm.releaseExecutionMemory(got, MemoryMode.ON_HEAP, this) - throw new SparkException(s"Can't acquire $size bytes memory to build hash relation") + freeMemory(got) + throw new SparkException(s"Can't acquire $size bytes memory to build hash relation, " + + s"got $got bytes") } } @@ -396,9 +396,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap init() - def spill(size: Long, trigger: MemoryConsumer): Long = { - 0L - } + def spill(size: Long, trigger: MemoryConsumer): Long = 0L /** * Returns whether all the keys are unique. @@ -408,9 +406,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap /** * Returns total memory consumption. */ - def getTotalMemoryConsumption: Long = { - array.length * 8 + page.length - } + def getTotalMemoryConsumption: Long = array.length * 8 + page.length /** * Returns the first slot of array that store the keys (sparse mode). @@ -423,9 +419,7 @@ private[execution] final class LongToUnsafeRowMap(var mm: TaskMemoryManager, cap /** * Returns the next probe in the array. 
*/ - private def nextSlot(pos: Int): Int = { - (pos + 2) & mask - } + private def nextSlot(pos: Int): Int = (pos + 2) & mask private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = { val offset = address >>> 32 @@ -674,9 +668,7 @@ private[joins] class LongHashedRelation( override def asReadOnlyCopy(): LongHashedRelation = new LongHashedRelation(nFields, map) - override def estimatedSize: Long = { - map.getTotalMemoryConsumption - } + override def estimatedSize: Long = map.getTotalMemoryConsumption override def get(key: InternalRow): Iterator[InternalRow] = { if (key.isNullAt(0)) { @@ -694,12 +686,9 @@ private[joins] class LongHashedRelation( } } - override def get(key: Long): Iterator[InternalRow] = - map.get(key, resultRow) + override def get(key: Long): Iterator[InternalRow] = map.get(key, resultRow) - override def getValue(key: Long): InternalRow = { - map.getValue(key, resultRow) - } + override def getValue(key: Long): InternalRow = map.getValue(key, resultRow) override def keyIsUnique: Boolean = map.keyIsUnique -- cgit v1.2.3 From 9f838bd24242866a687a2655a1b8ac2f5d562526 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 10 Apr 2016 20:46:15 -0700 Subject: [SPARK-14362][SPARK-14406][SQL][FOLLOW-UP] DDL Native Support: Drop View and Drop Table #### What changes were proposed in this pull request? This PR is to address the comment: https://github.com/apache/spark/pull/12146#discussion-diff-59092238. It removes the function `isViewSupported` from `SessionCatalog`. After the removal, we still can capture the user errors if users try to drop a table using `DROP VIEW`. #### How was this patch tested? Modified the existing test cases Author: gatorsmile Closes #12284 from gatorsmile/followupDropTable. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../spark/sql/catalyst/analysis/NoSuchItemException.scala | 2 +- .../spark/sql/catalyst/catalog/InMemoryCatalog.scala | 4 ++-- .../apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 9 ++------- .../scala/org/apache/spark/sql/execution/command/ddl.scala | 3 --- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++-- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 14 ++++++++++++-- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../org/apache/spark/sql/hive/HiveSessionCatalog.scala | 2 -- .../apache/spark/sql/hive/execution/HiveCommandSuite.scala | 2 +- 12 files changed, 24 insertions(+), 24 deletions(-) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 22eb3ec984..d747d4f83f 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1853,7 +1853,7 @@ test_that("approxQuantile() on a DataFrame", { test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) - expect_equal(grepl("Table not found", retError), TRUE) + expect_equal(grepl("Table or View not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3555a6d7fa..de40ddde1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -409,7 +409,7 @@ class Analyzer( catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { case _: NoSuchTableException => - u.failAnalysis(s"Table not found: ${u.tableName}") + u.failAnalysis(s"Table or View not found: ${u.tableName}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4880502398..d6a8c3eec8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -52,7 +52,7 @@ trait CheckAnalysis { case p if p.analyzed => // Skip already analyzed sub-plans case u: UnresolvedRelation => - u.failAnalysis(s"Table not found: ${u.tableIdentifier}") + u.failAnalysis(s"Table or View not found: ${u.tableIdentifier}") case operator: LogicalPlan => operator transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index e9f04eecf8..96fd1a027e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -33,7 +33,7 @@ class NoSuchDatabaseException(db: String) extends NoSuchItemException { } class NoSuchTableException(db: String, table: String) extends NoSuchItemException { - override def getMessage: String = s"Table $table not found in database $db" + override def getMessage: String = s"Table or View $table not found in database $db" } class NoSuchPartitionException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 1994acd1ad..f8a6fb74cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -62,7 +62,7 @@ class InMemoryCatalog extends ExternalCatalog { private def requireTableExists(db: String, table: String): Unit = { if (!tableExists(db, table)) { throw new AnalysisException( - s"Table not found: '$table' does not exist in database '$db'") + s"Table or View not found: '$table' does not exist in database '$db'") } } @@ -164,7 +164,7 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).tables.remove(table) } else { if (!ignoreIfNotExists) { - throw new AnalysisException(s"Table '$table' does not exist in database '$db'") + throw new AnalysisException(s"Table or View '$table' does not exist in database '$db'") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c1e5a485e7..34e1cb7315 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -242,11 +242,11 @@ class SessionCatalog( val table = formatTableName(name.table) if (name.database.isDefined || !tempTables.contains(table)) { // When ignoreIfNotExists is false, no exception is issued 
when the table does not exist. - // Instead, log it as an error message. This is consistent with Hive. + // Instead, log it as an error message. if (externalCatalog.tableExists(db, table)) { externalCatalog.dropTable(db, table, ignoreIfNotExists = true) } else if (!ignoreIfNotExists) { - logError(s"Table '${name.quotedString}' does not exist") + logError(s"Table or View '${name.quotedString}' does not exist") } } else { tempTables.remove(table) @@ -304,11 +304,6 @@ class SessionCatalog( name.database.isEmpty && tempTables.contains(formatTableName(name.table)) } - /** - * Return whether View is supported - */ - def isViewSupported: Boolean = false - /** * List all tables in the specified database, including temporary tables. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index e941736f9a..8a37cf8f4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -191,9 +191,6 @@ case class DropTable( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - if (isView && !catalog.isViewSupported) { - throw new AnalysisException(s"Not supported object: views") - } // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view // issue an exception. catalog.getTableMetadataOption(tableName).map(_.tableType match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 695dda269a..cdd404d699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1827,12 +1827,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e1 = intercept[AnalysisException] { sql("select * from in_valid_table") } - assert(e1.message.contains("Table not found")) + assert(e1.message.contains("Table or View not found")) val e2 = intercept[AnalysisException] { sql("select * from no_db.no_table").show() } - assert(e2.message.contains("Table not found")) + assert(e2.message.contains("Table or View not found")) val e3 = intercept[AnalysisException] { sql("select * from json.invalid_file") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e75e5f5cb2..c6479bf33e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -432,11 +432,21 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("DROP TABLE dbx.tab1") } - test("drop view") { + test("drop view in SQLContext") { + // SQLContext does not support create view. Log an error message, if tab1 does not exists + sql("DROP VIEW tab1") + + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assert(catalog.listTables("dbx") == Seq(tableIdent)) + val e = intercept[AnalysisException] { sql("DROP VIEW dbx.tab1") } - assert(e.getMessage.contains("Not supported object: views")) + assert( + e.getMessage.contains("Cannot drop a table with DROP VIEW. 
Please use DROP TABLE instead")) } private def convertToDatasourceTable( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 9ec8b9a9a6..bfc3d195ff 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -230,7 +230,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { runCliWithin(timeout = 2.minute, errorResponses = Seq("AnalysisException"))( "select * from nonexistent_table;" - -> "Error in query: Table not found: nonexistent_table;" + -> "Error in query: Table or View not found: nonexistent_table;" ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 875652c226..0cccc22e5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -70,8 +70,6 @@ private[sql] class HiveSessionCatalog( } } - override def isViewSupported: Boolean = true - // ---------------------------------------------------------------- // | Methods and fields for interacting with HiveMetastoreCatalog | // ---------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index 8de2bdcfc0..061d1512a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -96,7 +96,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val message1 = intercept[AnalysisException] { sql("SHOW TBLPROPERTIES badtable") }.getMessage - assert(message1.contains("Table badtable not found in database default")) + assert(message1.contains("Table or View badtable not found in database default")) // When key is not found, a row containing the error is returned. checkAnswer( -- cgit v1.2.3 From 1a0cca1fc81512d480ed0efc46114cb2b2189183 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 11 Apr 2016 09:03:11 +0100 Subject: [MINOR][DOCS] Fix wrong data types in JSON Datasets example. ## What changes were proposed in this pull request? This PR fixes the `age` data types from `integer` to `long` in `SQL Programming Guide: JSON Datasets`. ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #12290 from dongjoon-hyun/minor_fix_type_in_json_example. --- docs/sql-programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 63310be22c..2d9849d032 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1502,7 +1502,7 @@ val people = sqlContext.read.json(path) // The inferred schema can be visualized using the printSchema() method. people.printSchema() // root -// |-- age: integer (nullable = true) +// |-- age: long (nullable = true) // |-- name: string (nullable = true) // Register this DataFrame as a table. 
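Background for the documentation fix above (illustration only, not part of the patch): Spark's JSON schema inference maps integral JSON numbers to LongType, since their values may exceed the Int range, which is why the guide now shows `age: long` rather than `age: integer`. A minimal sketch, assuming a SQLContext named `sqlContext` and the bundled people.json example file referenced in the guide:

import org.apache.spark.sql.types.LongType

val people = sqlContext.read.json("examples/src/main/resources/people.json")
// Integral JSON values are inferred as LongType, not IntegerType
assert(people.schema("age").dataType == LongType)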
@@ -1540,7 +1540,7 @@ DataFrame people = sqlContext.read().json("examples/src/main/resources/people.js // The inferred schema can be visualized using the printSchema() method. people.printSchema(); // root -// |-- age: integer (nullable = true) +// |-- age: long (nullable = true) // |-- name: string (nullable = true) // Register this DataFrame as a table. @@ -1578,7 +1578,7 @@ people = sqlContext.read.json("examples/src/main/resources/people.json") # The inferred schema can be visualized using the printSchema() method. people.printSchema() # root -# |-- age: integer (nullable = true) +# |-- age: long (nullable = true) # |-- name: string (nullable = true) # Register this DataFrame as a table. @@ -1617,7 +1617,7 @@ people <- jsonFile(sqlContext, path) # The inferred schema can be visualized using the printSchema() method. printSchema(people) # root -# |-- age: integer (nullable = true) +# |-- age: long (nullable = true) # |-- name: string (nullable = true) # Register this DataFrame as a table. -- cgit v1.2.3 From e82d95bf63f57cefa02dc545ceb451ecdeedce28 Mon Sep 17 00:00:00 2001 From: Rekha Joshi Date: Mon, 11 Apr 2016 17:13:30 +0800 Subject: [SPARK-14372][SQL] Dataset.randomSplit() needs a Java version ## What changes were proposed in this pull request? 1.Added method randomSplitAsList() in Dataset for java for https://issues.apache.org/jira/browse/SPARK-14372 ## How was this patch tested? TestSuite Author: Rekha Joshi Author: Joshi Closes #12184 from rekhajoshm/SPARK-14372. --- .../src/main/scala/org/apache/spark/sql/Dataset.scala | 17 ++++++++++++++++- .../test/org/apache/spark/sql/JavaDatasetSuite.java | 10 ++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2f6d8d109f..e216945fbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.io.CharArrayWriter import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -1493,6 +1492,8 @@ class Dataset[T] private[sql]( * @param weights weights for splits, will be normalized if they don't sum to 1. * @param seed Seed for sampling. * + * For Java API, use [[randomSplitAsList]]. + * * @group typedrel * @since 2.0.0 */ @@ -1510,6 +1511,20 @@ class Dataset[T] private[sql]( }.toArray } + /** + * Returns a Java list that contains randomly split [[Dataset]] with the provided weights. + * + * @param weights weights for splits, will be normalized if they don't sum to 1. + * @param seed Seed for sampling. + * + * @group typedrel + * @since 2.0.0 + */ + def randomSplitAsList(weights: Array[Double], seed: Long): java.util.List[Dataset[T]] = { + val values = randomSplit(weights, seed) + java.util.Arrays.asList(values : _*) + } + /** * Randomly splits this [[Dataset]] with the provided weights. 
* diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index f26c57b301..5abd62cbc2 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -454,6 +454,16 @@ public class JavaDatasetSuite implements Serializable { Assert.assertEquals(data, ds.collectAsList()); } + @Test + public void testRandomSplit() { + List data = Arrays.asList("hello", "world", "from", "spark"); + Dataset ds = context.createDataset(data, Encoders.STRING()); + double[] arraySplit = {1, 2, 3}; + + List> randomSplit = ds.randomSplitAsList(arraySplit, 1); + Assert.assertEquals("wrong number of splits", randomSplit.size(), 3); + } + /** * For testing error messages when creating an encoder on a private class. This is done * here since we cannot create truly private classes in Scala. -- cgit v1.2.3 From 1c751fcf488189e5176546fe0d00f560ffcf1cec Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 11 Apr 2016 09:28:28 -0700 Subject: [SPARK-14500] [ML] Accept Dataset[_] instead of DataFrame in MLlib APIs ## What changes were proposed in this pull request? This PR updates MLlib APIs to accept `Dataset[_]` as input where `DataFrame` was the input type. This PR doesn't change the output type. In Java, `Dataset[_]` maps to `Dataset`, which includes `Dataset`. Some implementations were changed in order to return `DataFrame`. Tests and examples were updated. Note that this is a breaking change for subclasses of Transformer/Estimator. Lol, we don't have to rename the input argument, which has been `dataset` since Spark 1.2. TODOs: - [x] update MiMaExcludes (seems all covered by explicit filters from SPARK-13920) - [x] Python - [x] add a new test to accept Dataset[LabeledPoint] - [x] remove unused imports of Dataset ## How was this patch tested? Exiting unit tests with some modifications. cc: rxin jkbradley Author: Xiangrui Meng Closes #12274 from mengxr/SPARK-14500. 
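Illustration only (a hypothetical caller, not part of this patch): with fit() now taking Dataset[_], a strongly typed Dataset such as Dataset[LabeledPoint] can be handed to an estimator directly, without converting to a DataFrame first. A minimal sketch, assuming a training Dataset whose fields provide the usual "label" and "features" columns:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.Dataset

def trainOnTyped(training: Dataset[LabeledPoint]): Unit = {
  val lr = new LogisticRegression().setMaxIter(10)
  // fit(dataset: Dataset[_]) accepts the typed Dataset as-is after this change
  val model = lr.fit(training)
  println(model.coefficients)
}

The output types of fit() and transform() are unchanged, so existing DataFrame-based callers keep working; only subclasses that override train()/transform() need to adopt the new Dataset[_] signatures.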
--- .../spark/examples/ml/JavaDeveloperApiExample.java | 2 +- .../spark/examples/ml/DeveloperApiExample.scala | 4 ++-- .../main/scala/org/apache/spark/ml/Estimator.scala | 16 +++++++++------ .../main/scala/org/apache/spark/ml/Pipeline.scala | 12 +++++------ .../main/scala/org/apache/spark/ml/Predictor.scala | 14 ++++++------- .../scala/org/apache/spark/ml/Transformer.scala | 15 ++++++++------ .../spark/ml/classification/Classifier.scala | 6 +++--- .../ml/classification/DecisionTreeClassifier.scala | 4 ++-- .../spark/ml/classification/GBTClassifier.scala | 6 +++--- .../ml/classification/LogisticRegression.scala | 8 ++++---- .../MultilayerPerceptronClassifier.scala | 4 ++-- .../spark/ml/classification/NaiveBayes.scala | 4 ++-- .../apache/spark/ml/classification/OneVsRest.scala | 10 ++++----- .../classification/ProbabilisticClassifier.scala | 6 +++--- .../ml/classification/RandomForestClassifier.scala | 6 +++--- .../spark/ml/clustering/BisectingKMeans.scala | 8 ++++---- .../spark/ml/clustering/GaussianMixture.scala | 6 +++--- .../org/apache/spark/ml/clustering/KMeans.scala | 14 ++++++------- .../scala/org/apache/spark/ml/clustering/LDA.scala | 24 +++++++++++----------- .../evaluation/BinaryClassificationEvaluator.scala | 6 +++--- .../org/apache/spark/ml/evaluation/Evaluator.scala | 10 ++++----- .../MulticlassClassificationEvaluator.scala | 6 +++--- .../spark/ml/evaluation/RegressionEvaluator.scala | 6 +++--- .../org/apache/spark/ml/feature/Binarizer.scala | 3 ++- .../org/apache/spark/ml/feature/Bucketizer.scala | 3 ++- .../apache/spark/ml/feature/ChiSqSelector.scala | 6 ++++-- .../apache/spark/ml/feature/CountVectorizer.scala | 8 +++++--- .../org/apache/spark/ml/feature/HashingTF.scala | 5 +++-- .../scala/org/apache/spark/ml/feature/IDF.scala | 6 ++++-- .../org/apache/spark/ml/feature/Interaction.scala | 6 +++--- .../org/apache/spark/ml/feature/MaxAbsScaler.scala | 6 ++++-- .../org/apache/spark/ml/feature/MinMaxScaler.scala | 6 ++++-- .../apache/spark/ml/feature/OneHotEncoder.scala | 5 +++-- .../scala/org/apache/spark/ml/feature/PCA.scala | 6 ++++-- .../spark/ml/feature/QuantileDiscretizer.scala | 13 +++++++----- .../org/apache/spark/ml/feature/RFormula.scala | 18 ++++++++-------- .../apache/spark/ml/feature/SQLTransformer.scala | 6 +++--- .../apache/spark/ml/feature/StandardScaler.scala | 6 ++++-- .../apache/spark/ml/feature/StopWordsRemover.scala | 5 +++-- .../apache/spark/ml/feature/StringIndexer.scala | 13 +++++++----- .../apache/spark/ml/feature/VectorAssembler.scala | 7 ++++--- .../apache/spark/ml/feature/VectorIndexer.scala | 8 +++++--- .../org/apache/spark/ml/feature/VectorSlicer.scala | 5 +++-- .../org/apache/spark/ml/feature/Word2Vec.scala | 8 +++++--- .../spark/ml/r/AFTSurvivalRegressionWrapper.scala | 4 ++-- .../org/apache/spark/ml/r/KMeansWrapper.scala | 4 ++-- .../org/apache/spark/ml/r/NaiveBayesWrapper.scala | 4 ++-- .../org/apache/spark/ml/recommendation/ALS.scala | 10 ++++----- .../ml/regression/AFTSurvivalRegression.scala | 12 +++++------ .../ml/regression/DecisionTreeRegressor.scala | 11 +++++----- .../apache/spark/ml/regression/GBTRegressor.scala | 6 +++--- .../regression/GeneralizedLinearRegression.scala | 4 ++-- .../spark/ml/regression/IsotonicRegression.scala | 12 +++++------ .../spark/ml/regression/LinearRegression.scala | 6 +++--- .../ml/regression/RandomForestRegressor.scala | 6 +++--- .../apache/spark/ml/tuning/CrossValidator.scala | 12 +++++------ .../spark/ml/tuning/TrainValidationSplit.scala | 11 +++++----- .../org/apache/spark/mllib/linalg/Vectors.scala | 2 +- 
.../scala/org/apache/spark/ml/PipelineSuite.scala | 12 ++++++++--- .../classification/LogisticRegressionSuite.scala | 4 ++-- .../MultilayerPerceptronClassifierSuite.scala | 4 ++-- .../spark/ml/classification/NaiveBayesSuite.scala | 4 ++-- .../spark/ml/classification/OneVsRestSuite.scala | 6 +++--- .../spark/ml/clustering/BisectingKMeansSuite.scala | 4 ++-- .../spark/ml/clustering/GaussianMixtureSuite.scala | 4 ++-- .../apache/spark/ml/clustering/KMeansSuite.scala | 4 ++-- .../org/apache/spark/ml/clustering/LDASuite.scala | 4 ++-- .../org/apache/spark/ml/feature/NGramSuite.scala | 4 ++-- .../spark/ml/feature/StopWordsRemoverSuite.scala | 4 ++-- .../spark/ml/feature/StringIndexerSuite.scala | 2 +- .../apache/spark/ml/feature/TokenizerSuite.scala | 4 ++-- .../GeneralizedLinearRegressionSuite.scala | 8 ++++++++ .../spark/ml/tuning/CrossValidatorSuite.scala | 8 ++++---- .../ml/tuning/TrainValidationSplitSuite.scala | 6 +++--- .../spark/ml/util/DefaultReadWriteTest.scala | 4 ++-- 75 files changed, 296 insertions(+), 240 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index fbd8817669..0ba94786d4 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -146,7 +146,7 @@ class MyJavaLogisticRegression // This method is used by fit(). // In Java, we have to make it public since Java does not understand Scala's protected modifier. - public MyJavaLogisticRegressionModel train(Dataset dataset) { + public MyJavaLogisticRegressionModel train(Dataset dataset) { // Extract columns from data using helper method. JavaRDD oldDataset = extractLabeledPoints(dataset).toJavaRDD(); diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index c1f63c6a1d..8d127f9b35 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} /** * A simple example demonstrating how to write your own learning algorithm using Estimator, @@ -120,7 +120,7 @@ private class MyLogisticRegression(override val uid: String) def setMaxIter(value: Int): this.type = set(maxIter, value) // This method is used by fit() - override protected def train(dataset: DataFrame): MyLogisticRegressionModel = { + override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = { // Extract columns from data using helper method. 
val oldDataset = extractLabeledPoints(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 57e416591d..1247882d6c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.ml.param.{ParamMap, ParamPair} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -39,8 +39,9 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * Estimator's embedded ParamMap. * @return fitted model */ + @Since("2.0.0") @varargs - def fit(dataset: DataFrame, firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { + def fit(dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): M = { val map = new ParamMap() .put(firstParamPair) .put(otherParamPairs: _*) @@ -55,14 +56,16 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ - def fit(dataset: DataFrame, paramMap: ParamMap): M = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMap: ParamMap): M = { copy(paramMap).fit(dataset) } /** * Fits a model to the input data. */ - def fit(dataset: DataFrame): M + @Since("2.0.0") + def fit(dataset: Dataset[_]): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -74,7 +77,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage { * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ - def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { + @Since("2.0.0") + def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index afefaaa883..82066726a0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -31,7 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param.{Param, ParamMap, Params} import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -123,8 +123,8 @@ class Pipeline @Since("1.4.0") ( * @param dataset input dataset * @return fitted pipeline */ - @Since("1.2.0") - override def fit(dataset: DataFrame): PipelineModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PipelineModel = { transformSchema(dataset.schema, logging = true) val theStages = $(stages) // Search for the last estimator. 
@@ -291,10 +291,10 @@ class PipelineModel private[ml] ( this(uid, stages.asScala.toArray) } - @Since("1.2.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur)) + stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur)) } @Since("1.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index d23ae6f794..81140d1f7b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -83,7 +83,7 @@ abstract class Predictor[ /** @group setParam */ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] - override def fit(dataset: DataFrame): M = { + override def fit(dataset: Dataset[_]): M = { // This handles a few items such as schema validation. // Developers only need to implement train(). transformSchema(dataset.schema, logging = true) @@ -100,7 +100,7 @@ abstract class Predictor[ * @param dataset Training dataset * @return Fitted model */ - protected def train(dataset: DataFrame): M + protected def train(dataset: Dataset[_]): M /** * Returns the SQL DataType corresponding to the FeaturesType type parameter. @@ -120,7 +120,7 @@ abstract class Predictor[ * Extract [[labelCol]] and [[featuresCol]] from the given dataset, * and put it in an RDD with strong types. 
*/ - protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { + protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = { dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } @@ -171,18 +171,18 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, * @param dataset input dataset * @return transformed dataset with [[predictionCol]] of type [[Double]] */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { transformImpl(dataset) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") - dataset + dataset.toDF } } - protected def transformImpl(dataset: DataFrame): DataFrame = { + protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Any) => predict(features.asInstanceOf[FeaturesType]) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 2538c0f477..a3a2b55adc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml import scala.annotation.varargs -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -41,9 +41,10 @@ abstract class Transformer extends PipelineStage { * @param otherParamPairs other param pairs, overwrite embedded params * @return transformed dataset */ + @Since("2.0.0") @varargs def transform( - dataset: DataFrame, + dataset: Dataset[_], firstParamPair: ParamPair[_], otherParamPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() @@ -58,14 +59,16 @@ abstract class Transformer extends PipelineStage { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + @Since("2.0.0") + def transform(dataset: Dataset[_], paramMap: ParamMap): DataFrame = { this.copy(paramMap).transform(dataset) } /** * Transforms the input dataset. 
*/ - def transform(dataset: DataFrame): DataFrame + @Since("2.0.0") + def transform(dataset: Dataset[_]): DataFrame override def copy(extra: ParamMap): Transformer } @@ -113,7 +116,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val transformUDF = udf(this.createTransformFunc, outputDataType) dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 8186afc17a..473e801794 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -92,7 +92,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Output selected columns only. @@ -123,7 +123,7 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 4525bf71f6..300ae4339c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** @@ -82,7 +82,7 @@ final class DecisionTreeClassifier @Since("1.4.0") ( @Since("1.6.0") override def setSeed(value: Long): this.type = super.setSeed(value) - override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { + override protected def train(dataset: Dataset[_]): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index a2150fbcc3..46e8b89d01 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -149,7 +149,7 @@ final class GBTClassifier @Since("1.4.0") ( } } - override protected def train(dataset: DataFrame): GBTClassificationModel = { + override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -220,7 +220,7 @@ final class GBTClassificationModel private[ml]( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 268c3e32c3..4a3fe5c663 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -36,7 +36,7 @@ import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit} import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -257,12 +257,12 @@ class LogisticRegression @Since("1.2.0") ( this } - override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE train(dataset, handlePersistence) } - protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): + protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean): LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) val instances: RDD[Instance] = @@ -544,7 +544,7 @@ class LogisticRegressionModel private[spark] ( * @param dataset Test dataset to evaluate model on. 
*/ @Since("2.0.0") - def evaluate(dataset: DataFrame): LogisticRegressionSummary = { + def evaluate(dataset: Dataset[_]): LogisticRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, probabilityColName) = findSummaryModelAndProbabilityCol() new BinaryLogisticRegressionSummary(summaryModel.transform(dataset), diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 79bb2a8855..9ff5252e4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTo import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** Params for Multilayer Perceptron. */ private[ml] trait MultilayerPerceptronParams extends PredictorParams @@ -199,7 +199,7 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( * @param dataset Training dataset * @return Fitted model */ - override protected def train(dataset: DataFrame): MultilayerPerceptronClassificationModel = { + override protected def train(dataset: Dataset[_]): MultilayerPerceptronClassificationModel = { val myLayers = $(layers) val labels = myLayers.last val lpData = extractLabeledPoints(dataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala index 483ef0d88c..267d63b51e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala @@ -29,7 +29,7 @@ import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesMo import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} /** * Params for Naive Bayes Classifiers. 
@@ -101,7 +101,7 @@ class NaiveBayes @Since("1.5.0") ( def setModelType(value: String): this.type = set(modelType, value) setDefault(modelType -> OldNaiveBayes.Multinomial) - override protected def train(dataset: DataFrame): NaiveBayesModel = { + override protected def train(dataset: Dataset[_]): NaiveBayesModel = { val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val oldModel = OldNaiveBayes.train(oldDataset, $(smoothing), $(modelType)) NaiveBayesModel.fromOld(oldModel, this) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 263d54ce4d..4de1b877b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -140,8 +140,8 @@ final class OneVsRestModel private[ml] ( validateAndTransformSchema(schema, fitting = false, getClassifier.featuresDataType) } - @Since("1.4.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Check schema transformSchema(dataset.schema, logging = true) @@ -293,8 +293,8 @@ final class OneVsRest @Since("1.4.0") ( validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) } - @Since("1.4.0") - override def fit(dataset: DataFrame): OneVsRestModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): OneVsRestModel = { transformSchema(dataset.schema) // determine number of classes either from metadata if provided, or via computation. 
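Editor's note, not part of the patch: several hunks (ClassificationModel.transform above, ProbabilisticClassificationModel and the LDA/RFormula/StringIndexer model transforms further down) replace a bare `outputData` or `dataset` return with `outputData.toDF` / `dataset.toDF`. A hedged sketch of why is below; `passThrough` and `Point` are made-up names. Once the parameter is `Dataset[_]`, the value is no longer statically a DataFrame, so handing it back from a method declared to return `DataFrame` needs the cheap `toDF` view.

```scala
// Illustrative sketch only: shows why a method declared to return DataFrame
// needs `.toDF` when it hands its (now Dataset[_]) input back unchanged.
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions.lit

// Hypothetical record type used only for the demo.
case class Point(x: Double, y: Double)

object ToDFSketch {
  def passThrough(dataset: Dataset[_], addMarker: Boolean): DataFrame = {
    if (addMarker) {
      dataset.withColumn("marker", lit(1)) // already a DataFrame
    } else {
      // Returning `dataset` alone would not type-check: Dataset[_] is not a
      // DataFrame. toDF merely re-types the same plan as Dataset[Row]; it is
      // lazy and does not recompute the data.
      dataset.toDF
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("toDF-sketch").getOrCreate()
    import spark.implicits._
    val ds = Seq(Point(1.0, 2.0), Point(3.0, 4.0)).toDS()
    passThrough(ds, addMarker = false).printSchema()
    passThrough(ds, addMarker = true).show()
    spark.stop()
  }
}
```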
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 865614aa5c..d00fee12b0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} @@ -95,7 +95,7 @@ abstract class ProbabilisticClassificationModel[ * @param dataset input dataset * @return transformed dataset */ - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + @@ -145,7 +145,7 @@ abstract class ProbabilisticClassificationModel[ this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + " since no output columns were set.") } - outputData + outputData.toDF } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index cb42532271..9d80b8eb68 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -98,7 +98,7 @@ final class RandomForestClassifier @Since("1.4.0") ( override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestClassificationModel = { + override protected def train(dataset: Dataset[_]): RandomForestClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match { @@ -180,7 +180,7 @@ final class RandomForestClassificationModel private[ml] ( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 55f751c57f..6cc9117da3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering. {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -92,7 +92,7 @@ class BisectingKMeansModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -112,7 +112,7 @@ class BisectingKMeansModel private[ml] ( * centers. */ @Since("2.0.0") - def computeCost(dataset: DataFrame): Double = { + def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) @@ -215,7 +215,7 @@ class BisectingKMeans @Since("2.0.0") ( def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value) @Since("2.0.0") - override def fit(dataset: DataFrame): BisectingKMeansModel = { + override def fit(dataset: Dataset[_]): BisectingKMeansModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val bkm = new MLlibBisectingKMeans() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 120bf3cf9d..ead8ad7806 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -80,7 +80,7 @@ class GaussianMixtureModel private[ml] ( } @Since("2.0.0") - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val predUDF = udf((vector: Vector) => predict(vector)) val probUDF = udf((vector: Vector) => predictProbability(vector)) dataset.withColumn($(predictionCol), predUDF(col($(featuresCol)))) @@ -238,7 +238,7 @@ class GaussianMixture @Since("2.0.0") ( def setSeed(value: Long): this.type = set(seed, value) @Since("2.0.0") - override def fit(dataset: DataFrame): GaussianMixtureModel = { + override def fit(dataset: Dataset[_]): GaussianMixtureModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibGM() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index a8beef8b12..d716bc6887 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -105,8 +105,8 @@ class KMeansModel private[ml] ( copyValues(copied, extra) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predictUDF = udf((vector: Vector) => predict(vector)) dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } @@ -126,8 +126,8 @@ class KMeansModel private[ml] ( * model on the given data. */ // TODO: Replace the temp fix when we have proper evaluators defined for clustering. - @Since("1.6.0") - def computeCost(dataset: DataFrame): Double = { + @Since("2.0.0") + def computeCost(dataset: Dataset[_]): Double = { SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT) val data = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } parentModel.computeCost(data) @@ -254,8 +254,8 @@ class KMeans @Since("1.5.0") ( @Since("1.5.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.5.0") - override def fit(dataset: DataFrame): KMeansModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): KMeansModel = { val rdd = dataset.select(col($(featuresCol))).rdd.map { case Row(point: Vector) => point } val algo = new MLlibKMeans() diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 89a7a4ccf6..c57ceba4a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} import org.apache.spark.sql.types.StructType @@ -402,15 +402,15 @@ sealed abstract class LDAModel private[ml] ( * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver. * This implementation may be changed in the future. */ - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { if ($(topicDistributionCol).nonEmpty) { val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext)) - dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))) + dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF } else { logWarning("LDAModel.transform was called without any output columns. 
Set an output column" + " such as topicDistributionCol to produce results.") - dataset + dataset.toDF } } @@ -455,8 +455,8 @@ sealed abstract class LDAModel private[ml] ( * @param dataset test corpus to use for calculating log likelihood * @return variational lower bound on the log likelihood of the entire corpus */ - @Since("1.6.0") - def logLikelihood(dataset: DataFrame): Double = { + @Since("2.0.0") + def logLikelihood(dataset: Dataset[_]): Double = { val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) oldLocalModel.logLikelihood(oldDataset) } @@ -472,8 +472,8 @@ sealed abstract class LDAModel private[ml] ( * @param dataset test corpus to use for calculating perplexity * @return Variational upper bound on log perplexity per token. */ - @Since("1.6.0") - def logPerplexity(dataset: DataFrame): Double = { + @Since("2.0.0") + def logPerplexity(dataset: Dataset[_]): Double = { val oldDataset = LDA.getOldDataset(dataset, $(featuresCol)) oldLocalModel.logPerplexity(oldDataset) } @@ -840,8 +840,8 @@ class LDA @Since("1.6.0") ( @Since("1.6.0") override def copy(extra: ParamMap): LDA = defaultCopy(extra) - @Since("1.6.0") - override def fit(dataset: DataFrame): LDAModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): LDAModel = { transformSchema(dataset.schema, logging = true) val oldLDA = new OldLDA() .setK($(k)) @@ -873,7 +873,7 @@ class LDA @Since("1.6.0") ( private[clustering] object LDA extends DefaultParamsReadable[LDA] { /** Get dataset for spark.mllib LDA */ - def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = { + def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, Vector)] = { dataset .withColumn("docId", monotonicallyIncreasingId()) .select("docId", featuresCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 337ffbe90f..bde8c275fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.{Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.DoubleType /** @@ -69,8 +69,8 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va setDefault(metricName -> "areaUnderROC") - @Since("1.2.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT)) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index 0f22cca3a7..5f765c071b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{DeveloperApi, Since} import 
org.apache.spark.ml.param.{ParamMap, Params} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dataset /** * :: DeveloperApi :: @@ -36,8 +36,8 @@ abstract class Evaluator extends Params { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - @Since("1.5.0") - def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { + @Since("2.0.0") + def evaluate(dataset: Dataset[_], paramMap: ParamMap): Double = { this.copy(paramMap).evaluate(dataset) } @@ -46,8 +46,8 @@ abstract class Evaluator extends Params { * @param dataset a dataset that contains labels/observations and predictions. * @return metric */ - @Since("1.5.0") - def evaluate(dataset: DataFrame): Double + @Since("2.0.0") + def evaluate(dataset: Dataset[_]): Double /** * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 55ff44323a..3acfc221c9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.DoubleType /** @@ -68,8 +68,8 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid setDefault(metricName -> "f1") - @Since("1.5.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 9976d7ed43..4134e2dbc5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} import org.apache.spark.mllib.evaluation.RegressionMetrics -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType} @@ -70,8 +70,8 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui setDefault(metricName -> "rmse") - @Since("1.4.0") - override def evaluate(dataset: DataFrame): Double = { + @Since("2.0.0") + override def evaluate(dataset: Dataset[_]): Double = { val schema = dataset.schema val predictionColName = $(predictionCol) val predictionType = schema($(predictionCol)).dataType diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 2f8e3a0371..898ac2cc89 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -64,7 +64,8 @@ final class Binarizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema, logging = true) val schema = dataset.schema val inputType = schema($(inputCol)).dataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 33abc7c99d..10e622ace6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -68,7 +68,8 @@ final class Bucketizer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val bucketizer = udf { feature: Double => Bucketizer.binarySearchForBuckets($(splits), feature) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index b9e9d56853..cfecae7e0b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -77,7 +77,8 @@ final class ChiSqSelector(override val uid: String) /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def fit(dataset: DataFrame): ChiSqSelectorModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ChiSqSelectorModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(labelCol), $(featuresCol)).rdd.map { case Row(label: Double, features: Vector) => @@ -127,7 +128,8 @@ final class ChiSqSelectorModel private[ml] ( /** @group setParam */ def setLabelCol(value: String): this.type = set(labelCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) val newField = transformedSchema.last val selector = udf { chiSqSelector.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index 00abbbe29c..922670a41b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -147,7 +147,8 @@ class 
CountVectorizer(override val uid: String) setDefault(vocabSize -> (1 << 18), minDF -> 1) - override def fit(dataset: DataFrame): CountVectorizerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CountVectorizerModel = { transformSchema(dataset.schema, logging = true) val vocSize = $(vocabSize) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) @@ -224,7 +225,8 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */ private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) if (broadcastDict.isEmpty) { val dict = vocabulary.zipWithIndex.toMap diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 0f7ae5a100..467ad73074 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidat import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StructType} @@ -77,7 +77,8 @@ class HashingTF(override val uid: String) /** @group setParam */ def setBinary(value: Boolean): this.type = set(binary, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary)) val t = udf { terms: Seq[_] => hashingTF.transform(terms) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index f36cf503a0..5075b78c98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -76,7 +76,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa /** @group setParam */ def setMinDocFreq(value: Int): this.type = set(minDocFreq, value) - override def fit(dataset: DataFrame): IDFModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IDFModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val idf = new feature.IDF($(minDocFreq)).fit(input) @@ -115,7 +116,8 @@ class IDFModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val idf = udf { vec: Vector => idfModel.transform(vec) } dataset.withColumn($(outputCol), idf(col($(inputCol)))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index d3fe6e528f..9ca34e9ae2 100644 
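Editor's note, not part of the patch: the fit()/train()/evaluate() bodies in these hunks keep using the same extraction idiom (for example `dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v }` in IDF.fit). A hedged sketch of that idiom follows; the column names "label" and "feature" and the tuple return type are placeholders. The key detail is that `select(...)` always yields a DataFrame, so `.rdd` is an `RDD[Row]` even when the input is an arbitrary `Dataset[_]`.

```scala
// Illustrative sketch only: mirrors the select -> rdd -> Row pattern used by
// the fit()/train()/evaluate() bodies in this patch. Column names are assumed.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

object ExtractionSketch {
  def labelFeaturePairs(dataset: Dataset[_]): RDD[(Double, Double)] = {
    // select(...) returns a DataFrame, so .rdd is RDD[Row] here regardless of
    // the element type of the incoming Dataset[_].
    dataset
      .select(col("label").cast(DoubleType), col("feature").cast(DoubleType))
      .rdd
      .map { case Row(label: Double, feature: Double) => (label, feature) }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("extract-sketch").getOrCreate()
    import spark.implicits._
    val df = Seq((1.0, 0.5), (0.0, 1.5)).toDF("label", "feature")
    labelFeaturePairs(df).collect().foreach(println)
    spark.stop()
  }
}
```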
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -68,8 +68,8 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false)) } - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputFeatures = $(inputCols).map(c => dataset.schema(c)) val featureEncoders = getFeatureEncoders(inputFeatures) val featureAttrs = getFeatureAttrs(inputFeatures) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala index 7de5a4d5d3..e9df600c8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala @@ -66,7 +66,8 @@ class MaxAbsScaler @Since("2.0.0") (override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): MaxAbsScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MaxAbsScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) @@ -111,7 +112,8 @@ class MaxAbsScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // TODO: this looks hack, we may have to handle sparse and dense vectors separately. 
val maxAbsUnzero = Vectors.dense(maxAbs.toArray.map(x => if (x == 0) 1 else x)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index b13684a1cb..125becbb8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -103,7 +103,8 @@ class MinMaxScaler(override val uid: String) /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def fit(dataset: DataFrame): MinMaxScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): MinMaxScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val summary = Statistics.colStats(input) @@ -154,7 +155,8 @@ class MinMaxScalerModel private[ml] ( /** @group setParam */ def setMax(value: Double): this.type = set(max, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray val minArray = originalMin.toArray diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 4f67042629..99357793db 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} @@ -121,7 +121,8 @@ class OneHotEncoder(override val uid: String) extends Transformer StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // schema transformation val inputColName = $(inputCol) val outputColName = $(outputCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 305c3d187f..9cf722e121 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -68,7 +68,8 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams /** * Computes a [[PCAModel]] that contains the principal components of the input vectors. */ - override def fit(dataset: DataFrame): PCAModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): PCAModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v} val pca = new feature.PCA(k = $(k)) @@ -124,7 +125,8 @@ class PCAModel private[ml] ( * NOTE: Vectors to be transformed must be the same length * as the source vectors given to [[PCA.fit()]]. 
*/ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val pcaModel = new feature.PCAModel($(k), pc, explainedVariance) val pcaOp = udf { pcaModel.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index e486e92c12..efe8b93d82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.util.random.XORShiftRandom @@ -87,7 +87,8 @@ final class QuantileDiscretizer(override val uid: String) StructType(outputFields) } - override def fit(dataset: DataFrame): Bucketizer = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Bucketizer = { val samples = QuantileDiscretizer .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed)) .map { case Row(feature: Double) => feature } @@ -112,13 +113,15 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi /** * Sampling from the given dataset to collect quantile statistics. 
*/ - private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = { + private[feature] + def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = { val totalSamples = dataset.count() require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") val requiredSamples = math.max(numBins * numBins, minSamplesRequired) val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0) - dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() + dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) + .collect() } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 12a76dbbfb..3ac6c77669 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -29,7 +29,7 @@ import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.VectorUDT -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types._ /** @@ -103,7 +103,8 @@ class RFormula(override val uid: String) RFormulaParser.parse($(formula)).hasIntercept } - override def fit(dataset: DataFrame): RFormulaModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): RFormulaModel = { require(isDefined(formula), "Formula must be defined first.") val parsedFormula = RFormulaParser.parse($(formula)) val resolvedFormula = parsedFormula.resolve(dataset.schema) @@ -204,7 +205,8 @@ class RFormulaModel private[feature]( private[ml] val pipelineModel: PipelineModel) extends Model[RFormulaModel] with RFormulaBase with MLWritable { - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { checkCanTransform(dataset.schema) transformLabel(pipelineModel.transform(dataset)) } @@ -232,10 +234,10 @@ class RFormulaModel private[feature]( override def toString: String = s"RFormulaModel($resolvedFormula) (uid=$uid)" - private def transformLabel(dataset: DataFrame): DataFrame = { + private def transformLabel(dataset: Dataset[_]): DataFrame = { val labelName = resolvedFormula.label if (hasLabelCol(dataset.schema)) { - dataset + dataset.toDF } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => @@ -246,7 +248,7 @@ class RFormulaModel private[feature]( } else { // Ignore the label field. This is a hack so that this transformer can also work on test // datasets in a Pipeline. 
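Editor's note, not part of the patch: `getSampledInput` above now calls `dataset.toDF.sample(...)` rather than sampling the input directly. One plausible reading is sketched below; `sampleRows` and the 10000-row target are placeholders. On a value typed `Dataset[_]`, `collect()` only yields `Array[_]`, so converting to a DataFrame first pins the element type to `Row`, which matches the `Array[Row]` result the method promises.

```scala
// Hedged sketch, not from the patch: why sampling a Dataset[_] back to
// Array[Row] goes through toDF first. The sample-size target is arbitrary.
import org.apache.spark.sql.{Dataset, Row, SparkSession}

object SamplingSketch {
  def sampleRows(dataset: Dataset[_], seed: Long): Array[Row] = {
    val total = dataset.count()
    require(total > 0, "expected a non-empty input dataset")
    val fraction = math.min(10000.0 / total, 1.0)
    // toDF re-types the plan as Dataset[Row], so collect() is Array[Row].
    dataset.toDF.sample(withReplacement = false, fraction, seed).collect()
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sample-sketch").getOrCreate()
    import spark.implicits._
    val ds: Dataset[String] = Seq.fill(100)("row").toDS()
    println(sampleRows(ds, seed = 42L).length)
    spark.stop()
  }
}
```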
- dataset + dataset.toDF } } @@ -323,7 +325,7 @@ private class ColumnPruner(override val uid: String, val columnsToPrune: Set[Str def this(columnsToPrune: Set[String]) = this(Identifiable.randomUID("columnPruner"), columnsToPrune) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val columnsToKeep = dataset.columns.filter(!columnsToPrune.contains(_)) dataset.select(columnsToKeep.map(dataset.col): _*) } @@ -396,7 +398,7 @@ private class VectorAttributeRewriter( def this(vectorCol: String, prefixesToRewrite: Map[String, String]) = this(Identifiable.randomUID("vectorAttrRewriter"), vectorCol, prefixesToRewrite) - override def transform(dataset: DataFrame): DataFrame = { + override def transform(dataset: Dataset[_]): DataFrame = { val metadata = { val group = AttributeGroup.fromStructField(dataset.schema(vectorCol)) val attrs = group.attributes.get.map { attr => diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index e0ca45b9a6..95fe942c6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -22,7 +22,7 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.types.StructType /** @@ -63,8 +63,8 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor private val tableIdentifier: String = "__THIS__" - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 26ee8e1bf1..118a6e3e6a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -85,7 +85,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM /** @group setParam */ def setWithStd(value: Boolean): this.type = set(withStd, value) - override def fit(dataset: DataFrame): StandardScalerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StandardScalerModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler(withMean = $(withMean), withStd = $(withStd)) @@ -135,7 +136,8 @@ class StandardScalerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val scaler = new feature.StandardScalerModel(std, mean, $(withStd), $(withMean)) val scale = udf { scaler.transform _ } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 0a0e0b0960..b96bc48566 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructType} @@ -125,7 +125,8 @@ class StopWordsRemover(override val uid: String) setDefault(stopWords -> StopWords.English, caseSensitive -> false) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val outputSchema = transformSchema(dataset.schema) val t = if ($(caseSensitive)) { val stopWordsSet = $(stopWords).toSet diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index faa0f6f407..7e0d374f02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -80,7 +80,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): StringIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): StringIndexerModel = { val counts = dataset.select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) @@ -144,11 +145,12 @@ class StringIndexerModel ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { if (!dataset.schema.fieldNames.contains($(inputCol))) { logInfo(s"Input column ${$(inputCol)} does not exist during transformation. 
" + "Skip StringIndexerModel.") - return dataset + return dataset.toDF } validateAndTransformSchema(dataset.schema) @@ -286,7 +288,8 @@ class IndexToString private[ml] (override val uid: String) StructType(outputFields) } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata val values = if ($(labels).isEmpty) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 957e8e7a59..4d3e46e488 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -47,10 +47,11 @@ class VectorAssembler(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Schema transformation. val schema = dataset.schema - lazy val first = dataset.first() + lazy val first = dataset.toDF.first() val attrs = $(inputCols).flatMap { c => val field = schema(c) val index = schema.fieldIndex(c) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index bf4aef2a74..68b699d569 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet @@ -108,7 +108,8 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: DataFrame): VectorIndexerModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): VectorIndexerModel = { transformSchema(dataset.schema, logging = true) val firstRow = dataset.select($(inputCol)).take(1) require(firstRow.length == 1, s"VectorIndexer cannot be fit on an empty dataset.") @@ -345,7 +346,8 @@ class VectorIndexerModel private[ml] ( /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index b60e82de00..7a9468b87b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StructType @@ -89,7 +89,8 @@ final class VectorSlicer(override val uid: String) /** @group setParam */ def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Validity checks transformSchema(dataset.schema) val inputAttr = AttributeGroup.fromStructField(dataset.schema($(inputCol))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 95bae1c8a3..a72692960f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setMinCount(value: Int): this.type = set(minCount, value) - override def fit(dataset: DataFrame): Word2VecModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() @@ -219,7 +220,8 @@ class Word2VecModel private[ml] ( * Transform a sentence column to a vector column to represent the whole sentence. The transform * is performed by averaging all word vectors it contains. 
*/ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val vectors = wordVectors.getVectors .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 40590e71c4..2ae411555f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{AFTSurvivalRegression, AFTSurvivalRegressionModel} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class AFTSurvivalRegressionWrapper private ( pipeline: PipelineModel, @@ -43,7 +43,7 @@ private[r] class AFTSurvivalRegressionWrapper private ( features ++ Array("Log(scale)") } - def transform(dataset: DataFrame): DataFrame = { + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index ed735a4ea3..ee513579ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.clustering.{KMeans, KMeansModel} import org.apache.spark.ml.feature.VectorAssembler -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class KMeansWrapper private ( pipeline: PipelineModel) { @@ -52,7 +52,7 @@ private[r] class KMeansWrapper private ( } } - def transform(dataset: DataFrame): DataFrame = { + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 07383d393d..2cd709d2ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -21,7 +21,7 @@ import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NominalAttribute} import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.ml.feature.{IndexToString, RFormula} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} private[r] class NaiveBayesWrapper private ( pipeline: PipelineModel, @@ -36,7 +36,7 @@ private[r] class NaiveBayesWrapper private ( lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp) - def transform(dataset: DataFrame): DataFrame = { + def transform(dataset: Dataset[_]): DataFrame = { pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4a3ad662a0..36dce01590 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -40,7 +40,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -200,8 +200,8 @@ class ALSModel private[ml] ( @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - @Since("1.3.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => @@ -385,8 +385,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] this } - @Since("1.3.0") - override def fit(dataset: DataFrame): ALSModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ALSModel = { import dataset.sqlContext.implicits._ val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 3278974954..afed1f32b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -32,7 +32,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -183,7 +183,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset, * and put it in an RDD with strong types. 
*/ - protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = { + protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = { dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) .rdd.map { case Row(features: Vector, label: Double, censor: Double) => @@ -191,8 +191,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S } } - @Since("1.6.0") - override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) val instances = extractAFTPoints(dataset) val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -299,8 +299,8 @@ class AFTSurvivalRegressionModel private[ml] ( math.exp(BLAS.dot(coefficients, features) + intercept) } - @Since("1.6.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val predictUDF = udf { features: Vector => predict(features) } val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 1289a317ee..c04c416aaf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val /** @group setParam */ def setVarianceCol(value: String): this.type = set(varianceCol, value) - override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { + override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).impurityStats.calculate() } - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) transformImpl(dataset) } - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val predictUDF = udf { (features: Vector) => predict(features) } val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } - var output = dataset + var output = dataset.toDF if ($(predictionCol).nonEmpty) { output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 8eb2984f7b..0b52fe2d13 
100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ /** @@ -147,7 +147,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri } } - override protected def train(dataset: DataFrame): GBTRegressionModel = { + override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -209,7 +209,7 @@ final class GBTRegressionModel private[ml]( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 05bf64591b..00cf25dc54 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -196,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "irls") - override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = { + override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = { val familyObj = Family.fromName($(family)) val linkObj = if (isDefined(link)) { Link.fromName($(link)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index bd0b631d89..7a78ecbdf1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit, udf} import 
org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.storage.StorageLevel @@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures * Extracts (label, feature, weight) from input dataset. */ protected[ml] def extractWeightedLabeledPoints( - dataset: DataFrame): RDD[(Double, Double, Double)] = { + dataset: Dataset[_]): RDD[(Double, Double, Double)] = { val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) { val idx = $(featureIndex) val extract = udf { v: Vector => v(idx) } @@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri @Since("1.5.0") override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra) - @Since("1.5.0") - override def fit(dataset: DataFrame): IsotonicRegressionModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): IsotonicRegressionModel = { validateAndTransformSchema(dataset.schema, fitting = true) // Extract columns from data. If dataset is persisted, do not persist oldDataset. val instances = extractWeightedLabeledPoints(dataset) @@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] ( copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { val predict = dataset.schema($(featuresCol)).dataType match { case DoubleType => udf { feature: Double => oldModel.predict(feature) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index aacff4ea47..71e02730c7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -38,7 +38,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -158,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String def setSolver(value: String): this.type = set(solver, value) setDefault(solver -> "auto") - override protected def train(dataset: DataFrame): LinearRegressionModel = { + override protected def train(dataset: Dataset[_]): LinearRegressionModel = { // Extract the number of features before deciding optimization solver. val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map { case Row(features: Vector) => features.size @@ -417,7 +417,7 @@ class LinearRegressionModel private[ml] ( * @param dataset Test dataset to evaluate model on. 
*/ @Since("2.0.0") - def evaluate(dataset: DataFrame): LinearRegressionSummary = { + def evaluate(dataset: Dataset[_]): LinearRegressionSummary = { // Handle possible missing or invalid prediction columns val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol() new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName, diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 736cd9f776..bee13c2ebf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -93,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setFeatureSubsetStrategy(value: String): this.type = super.setFeatureSubsetStrategy(value) - override protected def train(dataset: DataFrame): RandomForestRegressionModel = { + override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) @@ -164,7 +164,7 @@ final class RandomForestRegressionModel private[ml] ( @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights - override protected def transformImpl(dataset: DataFrame): DataFrame = { + override protected def transformImpl(dataset: Dataset[_]): DataFrame = { val bcastModel = dataset.sqlContext.sparkContext.broadcast(this) val predictUDF = udf { (features: Any) => bcastModel.value.predict(features.asInstanceOf[Vector]) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 4d9d4d472e..de563d4fad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -33,7 +33,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -90,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.4.0") - override def fit(dataset: DataFrame): CrossValidatorModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): CrossValidatorModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -100,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) val epm = $(estimatorParamMaps) val numModels = epm.length val metrics = new Array[Double](epm.length) - val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed)) + val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), 
$(seed)) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.createDataFrame(training, schema).cache() val validationDataset = sqlCtx.createDataFrame(validation, schema).cache() @@ -209,8 +209,8 @@ class CrossValidatorModel private[ml] ( this(uid, bestModel, avgMetrics.asScala.toArray) } - @Since("1.4.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0f2179c2a1..12d6905510 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -20,6 +20,7 @@ package org.apache.spark.ml.tuning import java.util.{List => JList} import scala.collection.JavaConverters._ +import scala.language.existentials import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats @@ -31,7 +32,7 @@ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType /** @@ -89,8 +90,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("2.0.0") def setSeed(value: Long): this.type = set(seed, value) - @Since("1.5.0") - override def fit(dataset: DataFrame): TrainValidationSplitModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): TrainValidationSplitModel = { val schema = dataset.schema transformSchema(schema, logging = true) val sqlCtx = dataset.sqlContext @@ -207,8 +208,8 @@ class TrainValidationSplitModel private[ml] ( this(uid, bestModel, validationMetrics.asScala.toArray) } - @Since("1.5.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) bestModel.transform(dataset) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 0f0c3a2df5..5812cdde2c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -186,7 +186,7 @@ sealed trait Vector extends Serializable { * :: AlphaComponent :: * * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.DataFrame]]. + * via [[org.apache.spark.sql.Dataset]]. 
*/ @AlphaComponent class VectorUDT extends UserDefinedType[Vector] { diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index f3321fb5a1..a8c4ac6d05 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -51,6 +51,12 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val dataset3 = mock[DataFrame] val dataset4 = mock[DataFrame] + when(dataset0.toDF).thenReturn(dataset0) + when(dataset1.toDF).thenReturn(dataset1) + when(dataset2.toDF).thenReturn(dataset2) + when(dataset3.toDF).thenReturn(dataset3) + when(dataset4.toDF).thenReturn(dataset4) + when(estimator0.copy(any[ParamMap])).thenReturn(estimator0) when(model0.copy(any[ParamMap])).thenReturn(model0) when(transformer1.copy(any[ParamMap])).thenReturn(transformer1) @@ -213,7 +219,7 @@ class WritableStage(override val uid: String) extends Transformer with MLWritabl override def write: MLWriter = new DefaultParamsWriter(this) - override def transform(dataset: DataFrame): DataFrame = dataset + override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF override def transformSchema(schema: StructType): StructType = schema } @@ -234,7 +240,7 @@ class UnWritableStage(override val uid: String) extends Transformer { override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra) - override def transform(dataset: DataFrame): DataFrame = dataset + override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF override def transformSchema(schema: StructType): StructType = schema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 7eefaf2346..48db428130 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.lit class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var binaryDataset: DataFrame = _ private val eps: Double = 1e-5 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 06ff049b48..80547fad6a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -26,12 +26,12 @@ import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 4727cd436f..80a46fc70c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -27,11 +27,11 @@ import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 4131396726..f3e8fd11b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -30,12 +30,12 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.Metadata class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var rdd: RDD[LabeledPoint] = _ override def beforeAll(): Unit = { @@ -246,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid setMaxIter(1) - override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = { + override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = { val labelSchema = dataset.schema($(labelCol)) // check for label attribute propagation. 
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 18f2c994b4..e641d79c17 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 8edd44e5f1..1a274aea29 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index c684bc11cc..2076c745e2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -22,14 +22,14 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} private[clustering] case class TestRow(features: Vector) class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index a1c93891c7..ee8eae8f69 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, 
Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} object LDASuite { @@ -64,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead val k: Int = 5 val vocabSize: Int = 30 - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala index 58fda29aa1..e4e15f4331 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala @@ -22,7 +22,7 @@ import scala.beans.BeanInfo import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} @BeanInfo case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String]) @@ -92,7 +92,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe object NGramSuite extends SparkFunSuite { - def testNGram(t: NGram, dataset: DataFrame): Unit = { + def testNGram(t: NGram, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("nGrams", "wantedNGrams") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index a5b24c1856..3505befdf8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -20,10 +20,10 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} object StopWordsRemoverSuite extends SparkFunSuite { - def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = { + def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("filtered", "expected") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2c3255ef33..d0f3cdc841 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -115,7 +115,7 @@ class StringIndexerSuite .setInputCol("label") .setOutputCol("labelIndex") val df = sqlContext.range(0L, 10L).toDF() - assert(indexerModel.transform(df).eq(df)) + assert(indexerModel.transform(df).collect().toSet === df.collect().toSet) } test("StringIndexerModel can't overwrite output column") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala index 36e8e5d868..299f6223b2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import 
org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} @BeanInfo case class TokenizerTestData(rawText: String, wantedTokens: Array[String]) @@ -106,7 +106,7 @@ class RegexTokenizerSuite object RegexTokenizerSuite extends SparkFunSuite { - def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = { + def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = { t.transform(dataset) .select("tokens", "wantedTokens") .collect() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 2265464b51..4905f3e068 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -992,6 +992,14 @@ class GeneralizedLinearRegressionSuite assert(expected.coefficients === actual.coefficients) } } + + test("glm accepts Dataset[LabeledPoint]") { + val context = sqlContext + import context.implicits._ + new GeneralizedLinearRegression() + .setFamily("gaussian") + .fit(datasetGaussianIdentity.as[LabeledPoint]) + } } object GeneralizedLinearRegressionSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 7af3c6d6ed..3e734aabc5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.{StructField, StructType} class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -311,7 +311,7 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def fit(dataset: DataFrame): MyModel = { + override def fit(dataset: Dataset[_]): MyModel = { throw new UnsupportedOperationException } @@ -325,7 +325,7 @@ object CrossValidatorSuite extends SparkFunSuite { class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = { + override def evaluate(dataset: Dataset[_]): Double = { throw new UnsupportedOperationException } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 4030956fab..dbee47c847 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import 
org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType class TrainValidationSplitSuite @@ -158,7 +158,7 @@ object TrainValidationSplitSuite { class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol { - override def fit(dataset: DataFrame): MyModel = { + override def fit(dataset: Dataset[_]): MyModel = { throw new UnsupportedOperationException } @@ -172,7 +172,7 @@ object TrainValidationSplitSuite { class MyEvaluator extends Evaluator { - override def evaluate(dataset: DataFrame): Double = { + override def evaluate(dataset: Dataset[_]): Double = { throw new UnsupportedOperationException } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 16280473c6..7ebd7eb144 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} trait DefaultReadWriteTest extends TempDirectory { self: Suite => @@ -98,7 +98,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => def testEstimatorAndModelReadWrite[ E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, - dataset: DataFrame, + dataset: Dataset[_], testParams: Map[String, Any], checkModelData: (M, M) => Unit): Unit = { // Set some Params to make sure set Params are serialized. -- cgit v1.2.3 From 643b4e2257c56338b192f8554e2fe5523bea4bdf Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Mon, 11 Apr 2016 09:33:52 -0700 Subject: [SPARK-14510][MLLIB] Add args-checking for LDA and StreamingKMeans ## What changes were proposed in this pull request? add the checking for LDA and StreamingKMeans ## How was this patch tested? manual tests Author: Zheng RuiFeng Closes #12062 from zhengruifeng/initmodel. --- .../src/main/scala/org/apache/spark/mllib/clustering/LDA.scala | 10 +++++++--- .../org/apache/spark/mllib/clustering/StreamingKMeans.scala | 10 ++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 12813fd412..d999b9be8e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -130,7 +130,8 @@ class LDA private ( */ @Since("1.5.0") def setDocConcentration(docConcentration: Vector): this.type = { - require(docConcentration.size > 0, "docConcentration must have > 0 elements") + require(docConcentration.size == 1 || docConcentration.size == k, + s"Size of docConcentration must be 1 or ${k} but got ${docConcentration.size}") this.docConcentration = docConcentration this } @@ -260,15 +261,18 @@ class LDA private ( def getCheckpointInterval: Int = checkpointInterval /** - * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery + * Parameter for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 
10 means that + * the cache will get checkpointed every 10 iterations. Checkpointing helps with recovery * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be * important when LDA is run for many iterations. If the checkpoint directory is not set in - * [[org.apache.spark.SparkContext]], this setting is ignored. + * [[org.apache.spark.SparkContext]], this setting is ignored. (default = 10) * * @see [[org.apache.spark.SparkContext#setCheckpointDir]] */ @Since("1.3.0") def setCheckpointInterval(checkpointInterval: Int): this.type = { + require(checkpointInterval == -1 || checkpointInterval > 0, + s"Period between checkpoints must be -1 or positive but got ${checkpointInterval}") this.checkpointInterval = checkpointInterval this } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 4eb8fc049e..24e1cff0dc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + require(centers.size == weights.size, + "Number of initial centers must be equal to number of weights") + require(centers.size == k, + s"Number of initial centers must be ${k} but got ${centers.size}") + require(weights.forall(_ >= 0), + s"Weight for each inital center must be nonnegative but got [${weights.mkString(" ")}]") model = new StreamingKMeansModel(centers, weights) this } @@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") ( */ @Since("1.2.0") def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + require(dim > 0, + s"Number of dimensions must be positive but got ${dim}") + require(weight >= 0, + s"Weight for each center must be nonnegative but got ${weight}") val random = new XORShiftRandom(seed) val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) val weights = Array.fill(k)(weight) -- cgit v1.2.3 From efaf7d18205f5ae3a1c767942ee7d7320f7410de Mon Sep 17 00:00:00 2001 From: DB Tsai Date: Mon, 11 Apr 2016 09:35:47 -0700 Subject: [SPARK-14462][ML][MLLIB] Add the mllib-local build to maven pom ## What changes were proposed in this pull request? In order to separate the linear algebra, and vector matrix classes into a standalone jar, we need to setup the build first. This PR will create a new jar called mllib-local with minimal dependencies. The previous PR was failing the build because of `spark-core:test` dependency, and that was reverted. In this PR, `FunSuite` with `// scalastyle:ignore funsuite` in mllib-local test was used, similar to sketch. Thanks. ## How was this patch tested? Unit tests mengxr tedyu holdenk Author: DB Tsai Closes #12298 from dbtsai/dbtsai-mllib-local-build-fix. 
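For readers who want to try the new module on its own, a minimal sketch, assuming the artifact is published under the coordinates declared in the `mllib-local/pom.xml` added below (`org.apache.spark:spark-mllib-local_2.11:2.0.0-SNAPSHOT`), would be a single sbt dependency line:

```
// Illustrative build.sbt line only -- the coordinates are read off the pom.xml in this
// patch; a downstream project would normally depend on a released version rather than
// the in-tree SNAPSHOT.
libraryDependencies += "org.apache.spark" %% "spark-mllib-local" % "2.0.0-SNAPSHOT"
```

Within the Spark tree itself, the module can be exercised in isolation through the sbt goal registered in `dev/sparktestsupport/modules.py`, e.g. `build/sbt mllib-local/test`.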
--- dev/sparktestsupport/modules.py | 14 +++- mllib-local/pom.xml | 87 ++++++++++++++++++++++ .../scala/org/apache/spark/ml/DummyTesting.scala | 23 ++++++ .../org/apache/spark/ml/DummyTestingSuite.scala | 28 +++++++ mllib/pom.xml | 12 +++ pom.xml | 1 + project/SparkBuild.scala | 6 +- 7 files changed, 167 insertions(+), 4 deletions(-) create mode 100644 mllib-local/pom.xml create mode 100644 mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala create mode 100644 mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index bb04ec6ee6..c844bcff7e 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -256,9 +256,21 @@ streaming_flume_assembly = Module( ) +mllib_local = Module( + name="mllib-local", + dependencies=[], + source_file_regexes=[ + "mllib-local", + ], + sbt_test_goals=[ + "mllib-local/test", + ] +) + + mllib = Module( name="mllib", - dependencies=[streaming, sql], + dependencies=[mllib_local, streaming, sql], source_file_regexes=[ "data/mllib/", "mllib/", diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml new file mode 100644 index 0000000000..c56561f215 --- /dev/null +++ b/mllib-local/pom.xml @@ -0,0 +1,87 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.11 + 2.0.0-SNAPSHOT + ../pom.xml + + + org.apache.spark + spark-mllib-local_2.11 + + mllib-local + + jar + Spark Project ML Local Library + http://spark.apache.org/ + + + + org.scalanlp + breeze_${scala.binary.version} + 0.11.2 + + + + junit + junit + + + org.apache.commons + commons-math3 + + + + + org.apache.commons + commons-math3 + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.mockito + mockito-core + test + + + + + netlib-lgpl + + + com.github.fommil.netlib + all + ${netlib.java.version} + pom + + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala new file mode 100644 index 0000000000..6b3268cdfa --- /dev/null +++ b/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +// This is a private class testing if the new build works. To be removed soon. 
+private[ml] object DummyTesting { + private[ml] def add10(input: Double): Double = input + 10 +} diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala new file mode 100644 index 0000000000..51b7c2409f --- /dev/null +++ b/mllib-local/src/test/scala/org/apache/spark/ml/DummyTestingSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.scalatest.FunSuite // scalastyle:ignore funsuite + +// This is testing if the new build works. To be removed soon. +class DummyTestingSuite extends FunSuite { // scalastyle:ignore funsuite + + test("This is testing if the new build works.") { + assert(DummyTesting.add10(15) === 25) + } +} diff --git a/mllib/pom.xml b/mllib/pom.xml index 428176dcbf..e56eafc300 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -62,6 +62,18 @@ spark-graphx_${scala.binary.version} ${project.version}
    + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-mllib-local_${scala.binary.version} + ${project.version} + test-jar + test + org.scalanlp breeze_${scala.binary.version} diff --git a/pom.xml b/pom.xml index 4cbc6a2f11..38843b4f74 100644 --- a/pom.xml +++ b/pom.xml @@ -94,6 +94,7 @@ core graphx mllib + mllib-local tools streaming sql/catalyst diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 60124ef0a1..c5688ecec6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -47,9 +47,9 @@ object BuildCommons { ).map(ProjectRef(buildLocation, _)) val allProjects@Seq( - core, graphx, mllib, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* + core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, testTags, sketch, _* ) = Seq( - "core", "graphx", "mllib", "repl", "network-common", "network-shuffle", "launcher", "unsafe", + "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe", "test-tags", "sketch" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects @@ -254,7 +254,7 @@ object SparkBuild extends PomBuild { val mimaProjects = allProjects.filterNot { x => Seq( spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn, - unsafe, testTags, sketch + unsafe, testTags, sketch, mllibLocal ).contains(x) } -- cgit v1.2.3 From 652c4703099c1e5b17732fa019318358dc99ad50 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 11 Apr 2016 09:43:16 -0700 Subject: [SPARK-14528] [SQL] Fix same result of Union ## What changes were proposed in this pull request? This PR fix resultResult() for Union. ## How was this patch tested? Added regression test. Author: Davies Liu Closes #12295 from davies/fix_sameResult. --- .../main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 7 +++---- .../org/apache/spark/sql/catalyst/plans/SameResultSuite.scala | 7 ++++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0a11574f44..d4447ca32d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -312,18 +312,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** Args that have cleaned such that differences in expression id should not affect equality */ protected lazy val cleanArgs: Seq[Any] = { def cleanArg(arg: Any): Any = arg match { + // Children are checked using sameResult above. + case tn: TreeNode[_] if containsChild(tn) => null case e: Expression => cleanExpression(e).canonicalized case other => other } productIterator.map { - // Children are checked using sameResult above. 
- case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanArg(e) case s: Option[_] => s.map(cleanArg) case s: Seq[_] => s.map(cleanArg) case m: Map[_, _] => m.mapValues(cleanArg) - case other => other + case other => cleanArg(other) }.toSeq } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 37941cf34e..467f76193c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union} import org.apache.spark.sql.catalyst.util._ /** @@ -61,4 +61,9 @@ class SameResultSuite extends SparkFunSuite { test("sorts") { assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc)) } + + test("union") { + assertSameResult(Union(Seq(testRelation, testRelation2)), + Union(Seq(testRelation2, testRelation))) + } } -- cgit v1.2.3 From 5de26194a3aaeab9b7a8323107f614126c90441f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 11 Apr 2016 09:52:50 -0700 Subject: [SPARK-14502] [SQL] Add optimization for Binary Comparison Simplification ## What changes were proposed in this pull request? We can simplifies binary comparisons with semantically-equal operands: 1. Replace '<=>' with 'true' literal. 2. Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable. 3. Replace '<' and '>' with 'false' literal if both operands are non-nullable. For example, the following example plan ``` scala> sql("SELECT * FROM (SELECT explode(array(1,2,3)) a) T WHERE a BETWEEN a AND a+7").explain() ... : +- Filter ((a#59 >= a#59) && (a#59 <= (a#59 + 7))) ... ``` will be optimized into the following. ``` : +- Filter (a#47 <= (a#47 + 7)) ``` ## How was this patch tested? Pass the Jenkins tests including new `BinaryComparisonSimplificationSuite`. Author: Dongjoon Hyun Closes #12267 from dongjoon-hyun/SPARK-14502. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 24 ++++++ .../BinaryComparisonSimplificationSuite.scala | 95 ++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 619514e8aa..bad115d22f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -86,6 +86,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { BooleanSimplification, SimplifyConditionals, RemoveDispensableExpressions, + BinaryComparisonSimplification, PruneFilters, EliminateSorts, SimplifyCasts, @@ -786,6 +787,29 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Simplifies binary comparisons with semantically-equal expressions: + * 1) Replace '<=>' with 'true' literal. 
+ * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable. + * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable. + */ +object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + // True with equality + case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral + case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => + TrueLiteral + case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + + // False with inequality + case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + } + } +} + /** * Simplifies conditional expressions (if / case). */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala new file mode 100644 index 0000000000..7cd038570b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("Constant Folding", FixedPoint(50), + NullPropagation, + ConstantFolding, + BooleanSimplification, + BinaryComparisonSimplification, + PruneFilters) :: Nil + } + + val nullableRelation = LocalRelation('a.int.withNullability(true)) + val nonNullableRelation = LocalRelation('a.int.withNullability(false)) + + test("Preserve nullable exprs in general") { + for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) { + val plan = nullableRelation.where(e).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + } + + test("Preserve non-deterministic exprs") { + val plan = nonNullableRelation + .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + + test("Nullable Simplification Primitive: <=>") { + val plan = nullableRelation.select('a <=> 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze + comparePlans(actual, correctAnswer) + } + + test("Non-Nullable Simplification Primitive") { + val plan = nonNullableRelation + .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation + .select( + Alias(TrueLiteral, "(a = a)")(), + Alias(TrueLiteral, "(a <=> a)")(), + Alias(TrueLiteral, "(a <= a)")(), + Alias(TrueLiteral, "(a >= a)")(), + Alias(FalseLiteral, "(a < a)")(), + Alias(FalseLiteral, "(a > a)")()) + .analyze + comparePlans(actual, correctAnswer) + } + + test("Expression Normalization") { + val plan = nonNullableRelation.where( + 'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a && + DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a)) + .analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation.analyze + comparePlans(actual, correctAnswer) + } +} -- cgit v1.2.3 From 2dacc81ec31233e558855a26340ad4662d470387 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 11 Apr 2016 10:42:51 -0700 Subject: [SPARK-14494][SQL] Fix the race conditions in MemoryStream and MemorySink ## What changes were proposed in this pull request? Make sure accessing mutable variables in MemoryStream and MemorySink are protected by `synchronized`. This is probably why MemorySinkSuite failed here: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.2/650/testReport/junit/org.apache.spark.sql.streaming/MemorySinkSuite/registering_as_a_table/ ## How was this patch tested? Existing unit tests. Author: Shixiong Zhu Closes #12261 from zsxwing/memory-race-condition. 
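The core of the fix is a consistent locking discipline rather than new behavior: the mutable fields (`batches`, `currentOffset`) are annotated with `@GuardedBy("this")` and are only touched inside `this.synchronized`, and `MemoryStream.addData` now builds the Dataset before entering the critical section. A minimal sketch of the same pattern, using a hypothetical `SimpleSink` rather than the real classes:

```
import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer

// Illustrative only: mirrors the locking pattern applied to MemorySink in this patch.
// The mutable buffer is documented with @GuardedBy and every read or write of it,
// including the derived views, happens inside this.synchronized.
class SimpleSink[T] {
  @GuardedBy("this")
  private val batches = new ArrayBuffer[Seq[T]]()

  def addBatch(batch: Seq[T]): Unit = synchronized { batches.append(batch) }

  def allData: Seq[T] = synchronized { batches.flatten }

  def lastBatch: Seq[T] = synchronized { batches.last }
}
```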
--- .../spark/sql/execution/streaming/memory.scala | 25 ++++++++++++++-------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 351ef404a8..3820968324 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.streaming import java.util.concurrent.atomic.AtomicInteger +import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) protected val encoder = encoderFor[A] protected val logicalPlan = StreamingExecutionRelation(this) protected val output = logicalPlan.output + + @GuardedBy("this") protected val batches = new ArrayBuffer[Dataset[A]] + @GuardedBy("this") protected var currentOffset: LongOffset = new LongOffset(-1) def schema: StructType = encoder.schema @@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) def addData(data: TraversableOnce[A]): Offset = { import sqlContext.implicits._ + val ds = data.toVector.toDS() + logDebug(s"Adding ds: $ds") this.synchronized { currentOffset = currentOffset + 1 - val ds = data.toVector.toDS() - logDebug(s"Adding ds: $ds") batches.append(ds) currentOffset } @@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) override def toString: String = s"MemoryStream[${output.mkString(",")}]" - override def getOffset: Option[Offset] = if (batches.isEmpty) { - None - } else { - Some(currentOffset) + override def getOffset: Option[Offset] = synchronized { + if (batches.isEmpty) { + None + } else { + Some(currentOffset) + } } /** @@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) val startOrdinal = start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1 val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1 - val newBlocks = batches.slice(startOrdinal, endOrdinal) + val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) } logDebug( s"MemoryBatch [$startOrdinal, $endOrdinal]: ${newBlocks.flatMap(_.collect()).mkString(", ")}") @@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) */ class MemorySink(val schema: StructType) extends Sink with Logging { /** An order list of batches that have been written to this [[Sink]]. */ + @GuardedBy("this") private val batches = new ArrayBuffer[Array[Row]]() /** Returns all rows that are stored in this [[Sink]]. 
*/ @@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging { batches.flatten } - def lastBatch: Seq[Row] = batches.last + def lastBatch: Seq[Row] = synchronized { batches.last } def toDebugString: String = synchronized { batches.zipWithIndex.map { case (b, i) => @@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging { }.mkString("\n") } - override def addBatch(batchId: Long, data: DataFrame): Unit = { + override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized { if (batchId == batches.size) { logDebug(s"Committing batch $batchId") batches.append(data.collect()) -- cgit v1.2.3 From 89a41c5b7a3f727b44a7f615a1352ca006d12f73 Mon Sep 17 00:00:00 2001 From: Oliver Pierson Date: Mon, 11 Apr 2016 12:02:48 -0700 Subject: [SPARK-13600][MLLIB] Use approxQuantile from DataFrame stats in QuantileDiscretizer ## What changes were proposed in this pull request? QuantileDiscretizer can return an unexpected number of buckets in certain cases. This PR proposes to fix this issue and also refactor QuantileDiscretizer to use approxQuantiles from DataFrame stats functions. ## How was this patch tested? QuantileDiscretizerSuite unit tests (some existing tests will change or even be removed in this PR) Author: Oliver Pierson Closes #11553 from oliverpierson/SPARK-13600. --- .../spark/ml/feature/QuantileDiscretizer.scala | 119 +++++---------------- .../ml/feature/QuantileDiscretizerSuite.scala | 115 +++++++------------- 2 files changed, 65 insertions(+), 169 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index efe8b93d82..5c7993af64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol with HasSeed { /** - * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must + * Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. * default: 2 * @group param @@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getNumBuckets: Int = getOrDefault(numBuckets) + + /** + * Relative error (see documentation for + * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description) + * Must be a number in [0, 1]. + * default: 0.001 + * @group param + */ + val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " + + "for approxQuantile", + ParamValidators.inRange(0.0, 1.0)) + setDefault(relativeError -> 0.001) + + /** @group getParam */ + def getRelativeError: Double = getOrDefault(relativeError) } /** @@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, - * covering all real values. This attempts to find numBuckets partitions based on a sample of data, - * but it may find fewer depending on the data sample values. + * covering all real values. 
*/ @Experimental final class QuantileDiscretizer(override val uid: String) @@ -65,6 +79,9 @@ final class QuantileDiscretizer(override val uid: String) def this() = this(Identifiable.randomUID("quantileDiscretizer")) + /** @group setParam */ + def setRelativeError(value: Double): this.type = set(relativeError, value) + /** @group setParam */ def setNumBuckets(value: Int): this.type = set(numBuckets, value) @@ -89,11 +106,11 @@ final class QuantileDiscretizer(override val uid: String) @Since("2.0.0") override def fit(dataset: Dataset[_]): Bucketizer = { - val samples = QuantileDiscretizer - .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed)) - .map { case Row(feature: Double) => feature } - val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) - val splits = QuantileDiscretizer.getSplits(candidates) + val splits = dataset.stat.approxQuantile($(inputCol), + (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) + splits(0) = Double.NegativeInfinity + splits(splits.length - 1) = Double.PositiveInfinity + val bucketizer = new Bucketizer(uid).setSplits(splits) copyValues(bucketizer.setParent(this)) } @@ -104,92 +121,6 @@ final class QuantileDiscretizer(override val uid: String) @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - /** - * Minimum number of samples required for finding splits, regardless of number of bins. If - * the dataset has fewer rows than this value, the entire dataset will be used. - */ - private[spark] val minSamplesRequired: Int = 10000 - - /** - * Sampling from the given dataset to collect quantile statistics. - */ - private[feature] - def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = { - val totalSamples = dataset.count() - require(totalSamples > 0, - "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") - val requiredSamples = math.max(numBins * numBins, minSamplesRequired) - val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0) - dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()) - .collect() - } - - /** - * Compute split points with respect to the sample distribution. - */ - private[feature] - def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = { - val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) => - m + ((x, m.getOrElse(x, 0) + 1)) - } - val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1)) - val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { - valueCounts.dropRight(1).map(_._1) - } else { - val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1)) - val splitsBuilder = mutable.ArrayBuilder.make[Double] - var index = 1 - // currentCount: sum of counts of values that have been visited - var currentCount = valueCounts(0)._2 - // targetCount: target value for `currentCount`. If `currentCount` is closest value to - // `targetCount`, then current value is a split threshold. After finding a split threshold, - // `targetCount` is added by stride. 
- var targetCount = stride - while (index < valueCounts.length) { - val previousCount = currentCount - currentCount += valueCounts(index)._2 - val previousGap = math.abs(previousCount - targetCount) - val currentGap = math.abs(currentCount - targetCount) - // If adding count of current value to currentCount makes the gap between currentCount and - // targetCount smaller, previous value is a split threshold. - if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 - targetCount += stride - } - index += 1 - } - splitsBuilder.result() - } - } - - /** - * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as - * needed, and adding a default split value of 0 if no good candidates are found. - */ - private[feature] def getSplits(candidates: Array[Double]): Array[Double] = { - val effectiveValues = if (candidates.nonEmpty) { - if (candidates.head == Double.NegativeInfinity - && candidates.last == Double.PositiveInfinity) { - candidates.drop(1).dropRight(1) - } else if (candidates.head == Double.NegativeInfinity) { - candidates.drop(1) - } else if (candidates.last == Double.PositiveInfinity) { - candidates.dropRight(1) - } else { - candidates - } - } else { - candidates - } - - if (effectiveValues.isEmpty) { - Array(Double.NegativeInfinity, 0, Double.PositiveInfinity) - } else { - Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 25fabf64d5..8895d630a0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,78 +17,60 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - import org.apache.spark.ml.feature.QuantileDiscretizerSuite._ - - test("Test quantile discretizer") { - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 10, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) - - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 4, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - Array("-Infinity, 1.0", "1.0, 2.0", "2.0, 3.0", "3.0, Infinity")) - - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 3, - Array[Double](0, 1, 2, 2, 2, 2, 2, 2, 2), - Array("-Infinity, 2.0", "2.0, 3.0", "3.0, Infinity")) + test("Test observed number of buckets and their sizes match expected values") { + val sqlCtx = SQLContext.getOrCreate(sc) + import sqlCtx.implicits._ - checkDiscretizedData(sc, - Array[Double](1, 2, 3, 3, 3, 3, 3, 3, 3), - 2, - Array[Double](0, 1, 1, 1, 1, 1, 1, 1, 1), - Array("-Infinity, 2.0", "2.0, Infinity")) + val datasetSize = 100000 + val numBuckets = 5 + val df = sc.parallelize(1.0 to 
datasetSize by 1.0).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + val result = discretizer.fit(df).transform(df) - } + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets === numBuckets, + "Observed number of buckets does not equal expected number of buckets.") - test("Test getting splits") { - val splitTestPoints = Array( - Array[Double]() -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.NegativeInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.PositiveInfinity) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(Double.NegativeInfinity, Double.PositiveInfinity) - -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(0.0) -> Array(Double.NegativeInfinity, 0, Double.PositiveInfinity), - Array(1.0) -> Array(Double.NegativeInfinity, 1, Double.PositiveInfinity), - Array(0.0, 1.0) -> Array(Double.NegativeInfinity, 0, 1, Double.PositiveInfinity) - ) - for ((ori, res) <- splitTestPoints) { - assert(QuantileDiscretizer.getSplits(ori) === res, "Returned splits are invalid.") + val relativeError = discretizer.getRelativeError + val isGoodBucket = udf { + (size: Int) => math.abs( size - (datasetSize / numBuckets)) <= (relativeError * datasetSize) } + val numGoodBuckets = result.groupBy("result").count.filter(isGoodBucket($"count")).count + assert(numGoodBuckets === numBuckets, + "Bucket sizes are not within expected relative error tolerance.") } - test("Test splits on dataset larger than minSamplesRequired") { + test("Test transform method on unseen data") { val sqlCtx = SQLContext.getOrCreate(sc) import sqlCtx.implicits._ - val datasetSize = QuantileDiscretizer.minSamplesRequired + 1 - val numBuckets = 5 - val df = sc.parallelize((1.0 to datasetSize by 1.0).map(Tuple1.apply)).toDF("input") + val trainDF = sc.parallelize(1.0 to 100.0 by 1.0).map(Tuple1.apply).toDF("input") + val testDF = sc.parallelize(-10.0 to 110.0 by 1.0).map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") - .setNumBuckets(numBuckets) - .setSeed(1) + .setNumBuckets(5) - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count + val result = discretizer.fit(trainDF).transform(testDF) + val firstBucketSize = result.filter(result("result") === 0.0).count + val lastBucketSize = result.filter(result("result") === 4.0).count - assert(observedNumBuckets === numBuckets, - "Observed number of buckets does not equal expected number of buckets.") + assert(firstBucketSize === 30L, + s"Size of first bucket ${firstBucketSize} did not equal expected value of 30.") + assert(lastBucketSize === 31L, + s"Size of last bucket ${lastBucketSize} did not equal expected value of 31.") } test("read/write") { @@ -98,34 +80,17 @@ class QuantileDiscretizerSuite .setNumBuckets(6) testDefaultReadWrite(t) } -} - -private object QuantileDiscretizerSuite extends SparkFunSuite { - def checkDiscretizedData( - sc: SparkContext, - data: Array[Double], - numBucket: Int, - expectedResult: Array[Double], - expectedAttrs: Array[String]): Unit = { + test("Verify resulting model has parent") { val sqlCtx = SQLContext.getOrCreate(sc) import sqlCtx.implicits._ - val df = sc.parallelize(data.map(Tuple1.apply)).toDF("input") - val discretizer = new 
QuantileDiscretizer().setInputCol("input").setOutputCol("result") - .setNumBuckets(numBucket).setSeed(1) + val df = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(5) val model = discretizer.fit(df) assert(model.hasParent) - val result = model.transform(df) - - val transformedFeatures = result.select("result").collect() - .map { case Row(transformedFeature: Double) => transformedFeature } - val transformedAttrs = Attribute.fromStructField(result.schema("result")) - .asInstanceOf[NominalAttribute].values.get - - assert(transformedFeatures === expectedResult, - "Transformed features do not equal expected features.") - assert(transformedAttrs === expectedAttrs, - "Transformed attributes do not equal expected attributes.") } } -- cgit v1.2.3 From 3f0f40800bda98026f8a5e45c0a4ae2600c2d693 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 11 Apr 2016 14:01:05 -0700 Subject: [SPARK-14298][ML][MLLIB] Add unit test for EM LDA disable checkpointing ## What changes were proposed in this pull request? This is follow up for #12089, add unit test for EM LDA which test disable checkpointing when set ```checkpointInterval = -1```. ## How was this patch tested? unit test. cc jkbradley Author: Yanbo Liang Closes #12286 from yanboliang/spark-14298-followup. --- .../test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index ee8eae8f69..17d6e9fc2e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -289,4 +289,15 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead assert(model.getCheckpointFiles.isEmpty) } + + test("EM LDA disable checkpointing") { + // Checkpoint dir is set by MLlibTestSparkContext + val lda = new LDA().setK(2).setSeed(1).setOptimizer("em").setMaxIter(3) + .setCheckpointInterval(-1) + val model_ = lda.fit(dataset) + assert(model_.isInstanceOf[DistributedLDAModel]) + val model = model_.asInstanceOf[DistributedLDAModel] + + assert(model.getCheckpointFiles.isEmpty) + } } -- cgit v1.2.3 From 94de63053ecd709f44213d09bb43a8b2c5a8b4bb Mon Sep 17 00:00:00 2001 From: Luciano Resende Date: Mon, 11 Apr 2016 16:40:45 -0700 Subject: [SPARK-10521][SQL] Utilize Docker for test DB2 JDBC Dialect support Add integration tests based on docker to test DB2 JDBC dialect support Author: Luciano Resende Closes #9893 from lresende/SPARK-10521. 
--- external/docker-integration-tests/pom.xml | 30 ++++ .../spark/sql/jdbc/DB2IntegrationSuite.scala | 157 +++++++++++++++++++++ .../sql/jdbc/DockerJDBCIntegrationSuite.scala | 20 ++- .../spark/sql/jdbc/MySQLIntegrationSuite.scala | 2 + .../spark/sql/jdbc/OracleIntegrationSuite.scala | 2 + .../spark/sql/jdbc/PostgresIntegrationSuite.scala | 2 + pom.xml | 2 +- project/SparkBuild.scala | 4 +- 8 files changed, 215 insertions(+), 4 deletions(-) create mode 100644 external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 1764aa9465..17fd7d781c 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -34,6 +34,13 @@ docker-integration-tests + + + db2 + https://app.camunda.com/nexus/content/repositories/public/ + + + com.spotify @@ -180,5 +187,28 @@ + + + + com.ibm.db2.jcc + db2jcc4 + 10.5.0.5 + jar + diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala new file mode 100644 index 0000000000..4fe1ef6697 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DB2IntegrationSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.jdbc + +import java.math.BigDecimal +import java.sql.{Connection, Date, Timestamp} +import java.util.Properties + +import org.scalatest._ + +import org.apache.spark.tags.DockerTest + +@DockerTest +@Ignore // AMPLab Jenkins needs to be updated before shared memory works on docker +class DB2IntegrationSuite extends DockerJDBCIntegrationSuite { + override val db = new DatabaseOnDocker { + override val imageName = "lresende/db2express-c:10.5.0.5-3.10.0" + override val env = Map( + "DB2INST1_PASSWORD" -> "rootpass", + "LICENSE" -> "accept" + ) + override val usesIpc = true + override val jdbcPort: Int = 50000 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:db2://$ip:$port/foo:user=db2inst1;password=rootpass;" + override def getStartupProcessName: Option[String] = Some("db2start") + } + + override def dataPreparation(conn: Connection): Unit = { + conn.prepareStatement("CREATE TABLE tbl (x INTEGER, y VARCHAR(8))").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (42,'fred')").executeUpdate() + conn.prepareStatement("INSERT INTO tbl VALUES (17,'dave')").executeUpdate() + + conn.prepareStatement("CREATE TABLE numbers (onebit BIT(1), tenbits BIT(10), " + + "small SMALLINT, med MEDIUMINT, nor INT, big BIGINT, deci DECIMAL(40,20), flt FLOAT, " + + "dbl DOUBLE)").executeUpdate() + conn.prepareStatement("INSERT INTO numbers VALUES (b'0', b'1000100101', " + + "17, 77777, 123456789, 123456789012345, 123456789012345.123456789012345, " + + "42.75, 1.0000000000000002)").executeUpdate() + + conn.prepareStatement("CREATE TABLE dates (d DATE, t TIME, dt DATETIME, ts TIMESTAMP, " + + "yr YEAR)").executeUpdate() + conn.prepareStatement("INSERT INTO dates VALUES ('1991-11-09', '13:31:24', " + + "'1996-01-01 01:23:45', '2009-02-13 23:31:30', '2001')").executeUpdate() + + // TODO: Test locale conversion for strings. 
+ conn.prepareStatement("CREATE TABLE strings (a CHAR(10), b VARCHAR(10), c CLOB, d BLOB, " + + "e CHAR FOR BIT DATA)").executeUpdate() + conn.prepareStatement("INSERT INTO strings VALUES ('the', 'quick', 'brown', 'fox', 'jumps'") + .executeUpdate() + } + + test("Basic test") { + val df = sqlContext.read.jdbc(jdbcUrl, "tbl", new Properties) + val rows = df.collect() + assert(rows.length == 2) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 2) + assert(types(0).equals("class java.lang.Integer")) + assert(types(1).equals("class java.lang.String")) + } + + test("Numeric types") { + val df = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.Boolean")) + assert(types(1).equals("class java.lang.Long")) + assert(types(2).equals("class java.lang.Integer")) + assert(types(3).equals("class java.lang.Integer")) + assert(types(4).equals("class java.lang.Integer")) + assert(types(5).equals("class java.lang.Long")) + assert(types(6).equals("class java.math.BigDecimal")) + assert(types(7).equals("class java.lang.Double")) + assert(types(8).equals("class java.lang.Double")) + assert(rows(0).getBoolean(0) == false) + assert(rows(0).getLong(1) == 0x225) + assert(rows(0).getInt(2) == 17) + assert(rows(0).getInt(3) == 77777) + assert(rows(0).getInt(4) == 123456789) + assert(rows(0).getLong(5) == 123456789012345L) + val bd = new BigDecimal("123456789012345.12345678901234500000") + assert(rows(0).getAs[BigDecimal](6).equals(bd)) + assert(rows(0).getDouble(7) == 42.75) + assert(rows(0).getDouble(8) == 1.0000000000000002) + } + + test("Date types") { + val df = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 5) + assert(types(0).equals("class java.sql.Date")) + assert(types(1).equals("class java.sql.Timestamp")) + assert(types(2).equals("class java.sql.Timestamp")) + assert(types(3).equals("class java.sql.Timestamp")) + assert(types(4).equals("class java.sql.Date")) + assert(rows(0).getAs[Date](0).equals(Date.valueOf("1991-11-09"))) + assert(rows(0).getAs[Timestamp](1).equals(Timestamp.valueOf("1970-01-01 13:31:24"))) + assert(rows(0).getAs[Timestamp](2).equals(Timestamp.valueOf("1996-01-01 01:23:45"))) + assert(rows(0).getAs[Timestamp](3).equals(Timestamp.valueOf("2009-02-13 23:31:30"))) + assert(rows(0).getAs[Date](4).equals(Date.valueOf("2001-01-01"))) + } + + test("String types") { + val df = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + val rows = df.collect() + assert(rows.length == 1) + val types = rows(0).toSeq.map(x => x.getClass.toString) + assert(types.length == 9) + assert(types(0).equals("class java.lang.String")) + assert(types(1).equals("class java.lang.String")) + assert(types(2).equals("class java.lang.String")) + assert(types(3).equals("class java.lang.String")) + assert(types(4).equals("class java.lang.String")) + assert(types(5).equals("class java.lang.String")) + assert(types(6).equals("class [B")) + assert(types(7).equals("class [B")) + assert(types(8).equals("class [B")) + assert(rows(0).getString(0).equals("the")) + assert(rows(0).getString(1).equals("quick")) + assert(rows(0).getString(2).equals("brown")) + assert(rows(0).getString(3).equals("fox")) + assert(rows(0).getString(4).equals("jumps")) + 
assert(rows(0).getString(5).equals("over")) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](6), Array[Byte](116, 104, 101, 0))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](7), Array[Byte](108, 97, 122, 121))) + assert(java.util.Arrays.equals(rows(0).getAs[Array[Byte]](8), Array[Byte](100, 111, 103))) + } + + test("Basic write test") { + val df1 = sqlContext.read.jdbc(jdbcUrl, "numbers", new Properties) + val df2 = sqlContext.read.jdbc(jdbcUrl, "dates", new Properties) + val df3 = sqlContext.read.jdbc(jdbcUrl, "strings", new Properties) + df1.write.jdbc(jdbcUrl, "numberscopy", new Properties) + df2.write.jdbc(jdbcUrl, "datescopy", new Properties) + df3.write.jdbc(jdbcUrl, "stringscopy", new Properties) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index f73231fc80..c36f4d5f95 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -44,6 +44,11 @@ abstract class DatabaseOnDocker { */ val env: Map[String, String] + /** + * Wheather or not to use ipc mode for shared memory when starting docker image + */ + val usesIpc: Boolean + /** * The container-internal JDBC port that the database listens on. */ @@ -53,6 +58,11 @@ abstract class DatabaseOnDocker { * Return a JDBC URL that connects to the database running at the given IP address and port. */ def getJdbcUrl(ip: String, port: Int): String + + /** + * Optional process to run when container starts + */ + def getStartupProcessName: Option[String] } abstract class DockerJDBCIntegrationSuite @@ -97,17 +107,23 @@ abstract class DockerJDBCIntegrationSuite val dockerIp = DockerUtils.getDockerIp() val hostConfig: HostConfig = HostConfig.builder() .networkMode("bridge") + .ipcMode(if (db.usesIpc) "host" else "") .portBindings( Map(s"${db.jdbcPort}/tcp" -> List(PortBinding.of(dockerIp, externalPort)).asJava).asJava) .build() // Create the database container: - val config = ContainerConfig.builder() + val containerConfigBuilder = ContainerConfig.builder() .image(db.imageName) .networkDisabled(false) .env(db.env.map { case (k, v) => s"$k=$v" }.toSeq.asJava) .hostConfig(hostConfig) .exposedPorts(s"${db.jdbcPort}/tcp") - .build() + if(db.getStartupProcessName.isDefined) { + containerConfigBuilder + .cmd(db.getStartupProcessName.get) + } + val config = containerConfigBuilder.build() + // Create the database container: containerId = docker.createContainer(config).id // Start the container and wait until the database can accept JDBC connections: docker.startContainer(containerId) diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala index c68e4dc493..a70ed98b52 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala @@ -30,9 +30,11 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite { override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) + override val usesIpc = false override val jdbcPort: Int = 3306 override def 
getJdbcUrl(ip: String, port: Int): String = s"jdbc:mysql://$ip:$port/mysql?user=root&password=rootpass" + override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 8a0f938f7e..2fc174eb1b 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -52,9 +52,11 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo override val env = Map( "ORACLE_ROOT_PASSWORD" -> "oracle" ) + override val usesIpc = false override val jdbcPort: Int = 1521 override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:oracle:thin:system/oracle@//$ip:$port/xe" + override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index d55cdcf28b..79dd70116e 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -32,9 +32,11 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite { override val env = Map( "POSTGRES_PASSWORD" -> "rootpass" ) + override val usesIpc = false override val jdbcPort = 5432 override def getJdbcUrl(ip: String, port: Int): String = s"jdbc:postgresql://$ip:$port/postgres?user=postgres&password=rootpass" + override def getStartupProcessName: Option[String] = None } override def dataPreparation(conn: Connection): Unit = { diff --git a/pom.xml b/pom.xml index 38843b4f74..4585c8b9c2 100644 --- a/pom.xml +++ b/pom.xml @@ -666,7 +666,7 @@ com.spotify docker-client shaded - 3.4.0 + 3.6.6 test diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index c5688ecec6..a58dd7e7f1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -366,8 +366,10 @@ object Flume { object DockerIntegrationTests { // This serves to override the override specified in DependencyOverrides: lazy val settings = Seq( - dependencyOverrides += "com.google.guava" % "guava" % "18.0" + dependencyOverrides += "com.google.guava" % "guava" % "18.0", + resolvers ++= Seq("DB2" at "https://app.camunda.com/nexus/content/repositories/public/") ) + } /** -- cgit v1.2.3 From 6f27027d96ada29d8bb1d626f2cc7c856df3d597 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 11 Apr 2016 18:33:54 -0700 Subject: [SPARK-14475] Propagate user-defined context from driver to executors ## What changes were proposed in this pull request? This adds a new API call `TaskContext.getLocalProperty` for getting properties set in the driver from executors. These local properties are automatically propagated from the driver to executors. For streaming, the context for streaming tasks will be the initial driver context when ssc.start() is called. ## How was this patch tested? Unit tests. cc JoshRosen Author: Eric Liang Closes #12248 from ericl/sc-2813. 
--- .../main/scala/org/apache/spark/SparkContext.scala | 6 ++++-- .../main/scala/org/apache/spark/TaskContext.scala | 9 +++++++- .../scala/org/apache/spark/TaskContextImpl.scala | 5 +++++ .../scala/org/apache/spark/executor/Executor.scala | 17 ++++++++++++++- .../org/apache/spark/scheduler/DAGScheduler.scala | 4 ++-- .../org/apache/spark/scheduler/ResultTask.scala | 5 ++++- .../apache/spark/scheduler/ShuffleMapTask.scala | 9 +++++--- .../scala/org/apache/spark/scheduler/Task.scala | 20 +++++++++++++++--- .../scala/org/apache/spark/AccumulatorSuite.scala | 5 +++-- .../test/scala/org/apache/spark/ShuffleSuite.scala | 7 ++++--- .../apache/spark/memory/MemoryTestingUtils.scala | 3 +++ .../org/apache/spark/scheduler/FakeTask.scala | 4 +++- .../spark/scheduler/NotSerializableFakeTask.scala | 3 ++- .../apache/spark/scheduler/TaskContextSuite.scala | 24 ++++++++++++++++++---- .../spark/scheduler/TaskSetManagerSuite.scala | 4 ++-- .../spark/storage/BlockInfoManagerSuite.scala | 5 ++++- project/MimaExcludes.scala | 3 +++ .../UnsafeFixedWidthAggregationMapSuite.scala | 3 +++ .../execution/UnsafeKVExternalSorterSuite.scala | 3 +++ .../sql/execution/UnsafeRowSerializerSuite.scala | 3 ++- .../apache/spark/streaming/StreamingContext.scala | 8 ++++++++ .../spark/streaming/scheduler/JobGenerator.scala | 5 ----- .../spark/streaming/scheduler/JobScheduler.scala | 9 ++++++-- .../spark/streaming/StreamingContextSuite.scala | 10 ++++++++- 24 files changed, 138 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 9ec5cedf25..f0d152f05a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -602,8 +602,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** - * Set a local property that affects jobs submitted from this thread, such as the - * Spark fair scheduler pool. + * Set a local property that affects jobs submitted from this thread, such as the Spark fair + * scheduler pool. User-defined properties may also be set here. These properties are propagated + * through to worker tasks and can be accessed there via + * [[org.apache.spark.TaskContext#getLocalProperty]]. */ def setLocalProperty(key: String, value: String) { if (value == null) { diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index bfcacbf229..757c1b5116 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.Serializable +import java.util.Properties import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics @@ -64,7 +65,7 @@ object TaskContext { * An empty task context that does not represent an actual task. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, null) + new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) } } @@ -162,6 +163,12 @@ abstract class TaskContext extends Serializable { */ def taskAttemptId(): Long + /** + * Get a local property set upstream in the driver, or null if it is missing. See also + * [[org.apache.spark.SparkContext.setLocalProperty]]. 
+ */ + def getLocalProperty(key: String): String + @DeveloperApi def taskMetrics(): TaskMetrics diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index c9354b3e55..fa0b2d3d28 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -17,6 +17,8 @@ package org.apache.spark +import java.util.Properties + import scala.collection.mutable.ArrayBuffer import org.apache.spark.executor.TaskMetrics @@ -32,6 +34,7 @@ private[spark] class TaskContextImpl( override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, @transient private val metricsSystem: MetricsSystem, initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll()) extends TaskContext @@ -118,6 +121,8 @@ private[spark] class TaskContextImpl( override def isInterrupted(): Boolean = interrupted + override def getLocalProperty(key: String): String = localProperties.getProperty(key) + override def getMetricsSources(sourceName: String): Seq[Source] = metricsSystem.getSourcesByName(sourceName) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index afa4d6093a..9f94fdef24 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -21,6 +21,7 @@ import java.io.{File, NotSerializableException} import java.lang.management.ManagementFactory import java.net.URL import java.nio.ByteBuffer +import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ @@ -206,9 +207,16 @@ private[spark] class Executor( startGCTime = computeTotalGcTime() try { - val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) + val (taskFiles, taskJars, taskProps, taskBytes) = + Task.deserializeWithDependencies(serializedTask) + + // Must be set before updateDependencies() is called, in case fetching dependencies + // requires access to properties contained within (e.g. for access control). + Executor.taskDeserializationProps.set(taskProps) + updateDependencies(taskFiles, taskJars) task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader) + task.localProperties = taskProps task.setTaskMemoryManager(taskMemoryManager) // If this task has been killed before we deserialized it, let's quit now. Otherwise, @@ -506,3 +514,10 @@ private[spark] class Executor( heartbeater.scheduleAtFixedRate(heartbeatTask, initialDelay, intervalMs, TimeUnit.MILLISECONDS) } } + +private[spark] object Executor { + // This is reserved for internal use by components that need to read task properties before a + // task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be + // used instead. 
+ val taskDeserializationProps: ThreadLocal[Properties] = new ThreadLocal[Properties] +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 5cdc91316b..4609b244e6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1036,7 +1036,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.internalAccumulators) + taskBinary, part, locs, stage.internalAccumulators, properties) } case stage: ResultStage => @@ -1046,7 +1046,7 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, stage.internalAccumulators) + taskBinary, part, locs, id, properties, stage.internalAccumulators) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index cd2736e196..db6276f75d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io._ import java.nio.ByteBuffer +import java.util.Properties import org.apache.spark._ import org.apache.spark.broadcast.Broadcast @@ -38,6 +39,7 @@ import org.apache.spark.rdd.RDD * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). + * @param localProperties copy of thread-local properties set by the user on the driver side. * @param _initialAccums initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. @@ -49,8 +51,9 @@ private[spark] class ResultTask[T, U]( partition: Partition, locs: Seq[TaskLocation], val outputId: Int, + localProperties: Properties, _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll()) - extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums) + extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index e30964a01b..b7cab7013e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.Properties import scala.language.existentials @@ -42,6 +43,7 @@ import org.apache.spark.shuffle.ShuffleWriter * @param _initialAccums initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. + * @param localProperties copy of thread-local properties set by the user on the driver side. 
*/ private[spark] class ShuffleMapTask( stageId: Int, @@ -49,13 +51,14 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - _initialAccums: Seq[Accumulator[_]]) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums) + _initialAccums: Seq[Accumulator[_]], + localProperties: Properties) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { - this(0, 0, null, new Partition { override def index: Int = 0 }, null, null) + this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, new Properties) } @transient private val preferredLocs: Seq[TaskLocation] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index c91d8fbfc4..1ff9d7795f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.util.Properties import scala.collection.mutable.HashMap @@ -46,12 +47,14 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti * @param initialAccumulators initial set of accumulators to be used in this task for tracking * internal metrics. Other accumulators will be registered later when * they are deserialized on the executors. + * @param localProperties copy of thread-local properties set by the user on the driver side. */ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - val initialAccumulators: Seq[Accumulator[_]]) extends Serializable { + val initialAccumulators: Seq[Accumulator[_]], + @transient var localProperties: Properties) extends Serializable { /** * Called by [[org.apache.spark.executor.Executor]] to run this task. @@ -71,6 +74,7 @@ private[spark] abstract class Task[T]( taskAttemptId, attemptNumber, taskMemoryManager, + localProperties, metricsSystem, initialAccumulators) TaskContext.setTaskContext(context) @@ -212,6 +216,11 @@ private[spark] object Task { dataOut.writeLong(timestamp) } + // Write the task properties separately so it is available before full task deserialization. 
+ val propBytes = Utils.serialize(task.localProperties) + dataOut.writeInt(propBytes.length) + dataOut.write(propBytes) + // Write the task itself and finish dataOut.flush() val taskBytes = serializer.serialize(task) @@ -227,7 +236,7 @@ private[spark] object Task { * @return (taskFiles, taskJars, taskBytes) */ def deserializeWithDependencies(serializedTask: ByteBuffer) - : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = { + : (HashMap[String, Long], HashMap[String, Long], Properties, ByteBuffer) = { val in = new ByteBufferInputStream(serializedTask) val dataIn = new DataInputStream(in) @@ -246,8 +255,13 @@ private[spark] object Task { taskJars(dataIn.readUTF()) = dataIn.readLong() } + val propLength = dataIn.readInt() + val propBytes = new Array[Byte](propLength) + dataIn.readFully(propBytes, 0, propLength) + val taskProps = Utils.deserialize[Properties](propBytes) + // Create a sub-buffer for the rest of the data, which is the serialized Task object val subBuffer = serializedTask.slice() // ByteBufferInputStream will have read just up to task - (taskFiles, taskJars, subBuffer) + (taskFiles, taskJars, taskProps, subBuffer) } } diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index ec192a8543..37879d11ca 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import java.util.Properties import java.util.concurrent.Semaphore import javax.annotation.concurrent.GuardedBy @@ -292,7 +293,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance) // Now we're on the executors. // Deserialize the task and assert that its accumulators are zero'ed out. 
- val (_, _, taskBytes) = Task.deserializeWithDependencies(taskSer) + val (_, _, _, taskBytes) = Task.deserializeWithDependencies(taskSer) val taskDeser = serInstance.deserialize[DummyTask]( taskBytes, Thread.currentThread.getContextClassLoader) // Assert that executors see only zeros @@ -403,6 +404,6 @@ private class SaveInfoListener extends SparkListener { private[spark] class DummyTask( val internalAccums: Seq[Accumulator[_]], val externalAccums: Seq[Accumulator[_]]) - extends Task[Int](0, 0, 0, internalAccums) { + extends Task[Int](0, 0, 0, internalAccums, new Properties) { override def runTask(c: TaskContext): Int = 1 } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 6ffa1c8ac1..00f3f15c45 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import java.util.Properties import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} import org.scalatest.Matchers @@ -335,7 +336,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem, + new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem, InternalAccumulator.create(sc))) val data1 = (1 to 10).map { x => x -> x} @@ -343,7 +344,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem, + new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem, InternalAccumulator.create(sc))) val data2 = (11 to 20).map { x => x -> x} @@ -372,7 +373,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem, + new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem, InternalAccumulator.create(sc))) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 2b5e4b80e9..362cd861cc 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.memory +import java.util.Properties + import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} /** @@ -31,6 +33,7 @@ object MemoryTestingUtils { taskAttemptId = 0, attemptNumber = 0, taskMemoryManager = taskMemoryManager, + localProperties = new Properties, metricsSystem = env.metricsSystem) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index f7e16af9d3..e3e6df6831 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -17,12 +17,14 @@ package 
org.apache.spark.scheduler +import java.util.Properties + import org.apache.spark.TaskContext class FakeTask( stageId: Int, prefLocs: Seq[TaskLocation] = Nil) - extends Task[Int](stageId, 0, 0, Seq.empty) { + extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 1dca4bd89f..76a7087645 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -18,6 +18,7 @@ package org.apache.spark.scheduler import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.util.Properties import org.apache.spark.TaskContext @@ -25,7 +26,7 @@ import org.apache.spark.TaskContext * A Task implementation that fails to serialize. */ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { + extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index c4cf2f9f70..86911d2211 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -17,12 +17,14 @@ package org.apache.spark.scheduler +import java.util.Properties + import org.mockito.Matchers.any import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.executor.TaskMetricsSuite +import org.apache.spark.executor.{Executor, TaskMetricsSuite} import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils @@ -59,7 +61,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) - val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties) intercept[RuntimeException] { task.run(0, 0, null) } @@ -79,7 +82,8 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) - val task = new ResultTask[String, String](0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0) + val task = new ResultTask[String, String]( + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties) intercept[RuntimeException] { task.run(0, 0, null) } @@ -170,9 +174,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val initialAccums = InternalAccumulator.createAll() // Create a 
dummy task. We won't end up running this; we just want to collect // accumulator updates from it. - val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]]) { + val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]], new Properties) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), + new Properties, SparkEnv.get.metricsSystem, initialAccums) context.taskMetrics.registerAccumulator(acc1) @@ -189,6 +194,17 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4) } + test("localProperties are propagated to executors correctly") { + sc = new SparkContext("local", "test") + sc.setLocalProperty("testPropKey", "testPropValue") + val res = sc.parallelize(Array(1), 1).map(i => i).map(i => { + val inTask = TaskContext.get().getLocalProperty("testPropKey") + val inDeser = Executor.taskDeserializationProps.get().getProperty("testPropKey") + s"$inTask,$inDeser" + }).collect() + assert(res === Array("testPropValue,testPropValue")) + } + } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 167d3fd2e4..ade8e84d84 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.util.Random +import java.util.{Properties, Random} import scala.collection.Map import scala.collection.mutable @@ -138,7 +138,7 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. 
*/ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) { val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 7ee76aa4c6..9d1bd7ec89 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.storage +import java.util.Properties + import scala.concurrent.{Await, ExecutionContext, Future} import scala.language.implicitConversions import scala.reflect.ClassTag @@ -58,7 +60,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { - TaskContext.setTaskContext(new TaskContextImpl(0, 0, taskAttemptId, 0, null, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 290de794dc..a30581eb48 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -625,6 +625,9 @@ object MimaExcludes { ) ++ Seq( // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") + ) ++ Seq( + // [SPARK-14475] Propagate user-defined context from driver to executors + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty") ) case v if v.startsWith("1.6") => Seq( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 4dc7d3461c..c1555114e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.Properties + import scala.collection.mutable import scala.util.{Random, Try} import scala.util.control.NonFatal @@ -71,6 +73,7 @@ class UnsafeFixedWidthAggregationMapSuite taskAttemptId = Random.nextInt(10000), attemptNumber = 0, taskMemoryManager = taskMemoryManager, + localProperties = new Properties, metricsSystem = null)) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 476d93fc2a..03d4be8ee5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution +import java.util.Properties + import scala.util.Random import org.apache.spark._ @@ -117,6 +119,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { taskAttemptId = 98456, attemptNumber = 0, taskMemoryManager = taskMemMgr, + localProperties = new 
Properties, metricsSystem = null)) val sorter = new UnsafeKVExternalSorter( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 1f3779373b..7db1f9654b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} +import java.util.Properties import org.apache.spark._ import org.apache.spark.memory.TaskMemoryManager @@ -113,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) val taskContext = new TaskContextImpl( - 0, 0, 0, 0, taskMemoryManager, null, InternalAccumulator.create(sc)) + 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.create(sc)) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index cc187f5cb4..928739a416 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -18,6 +18,7 @@ package org.apache.spark.streaming import java.io.{InputStream, NotSerializableException} +import java.util.Properties import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import scala.collection.Map @@ -25,6 +26,7 @@ import scala.collection.mutable.Queue import scala.reflect.ClassTag import scala.util.control.NonFatal +import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} @@ -198,6 +200,10 @@ class StreamingContext private[streaming] ( private val startSite = new AtomicReference[CallSite](null) + // Copy of thread-local properties from SparkContext. These properties will be set in all tasks + // submitted by this StreamingContext after start. + private[streaming] val savedProperties = new AtomicReference[Properties](new Properties) + private[streaming] def getStartSite(): CallSite = startSite.get() private var shutdownHookRef: AnyRef = _ @@ -573,6 +579,8 @@ class StreamingContext private[streaming] ( sparkContext.setCallSite(startSite.get) sparkContext.clearJobGroup() sparkContext.setLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL, "false") + savedProperties.set(SerializationUtils.clone( + sparkContext.localProperties.get()).asInstanceOf[Properties]) scheduler.start() } state = StreamingContextState.ACTIVE diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 86f069b0bd..307ff1f7ec 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -241,11 +241,6 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Generate jobs and perform checkpoint for the given `time`. 
*/ private def generateJobs(time: Time) { - // Set the SparkEnv in this thread, so that job generation code can access the environment - // Example: BlockRDDs are created in this thread, and it needs to access BlockManager - // Update: This is probably redundant after threadlocal stuff in SparkEnv has been removed. - SparkEnv.set(ssc.env) - // Checkpoint all RDDs marked for checkpointing to ensure their lineages are // truncated periodically. Otherwise, we may run into stack overflows (SPARK-6847). ssc.sparkContext.setLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS, "true") diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index 303c325274..ac18f73ea8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -17,11 +17,14 @@ package org.apache.spark.streaming.scheduler +import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.JavaConverters._ import scala.util.Failure +import org.apache.commons.lang.SerializationUtils + import org.apache.spark.internal.Logging import org.apache.spark.rdd.{PairRDDFunctions, RDD} import org.apache.spark.streaming._ @@ -214,7 +217,10 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { import JobScheduler._ def run() { + val oldProps = ssc.sparkContext.getLocalProperties try { + ssc.sparkContext.setLocalProperties( + SerializationUtils.clone(ssc.savedProperties.get()).asInstanceOf[Properties]) val formattedTime = UIUtils.formatBatchTime( job.time.milliseconds, ssc.graph.batchDuration.milliseconds, showYYYYMMSS = false) val batchUrl = s"/streaming/batch/?id=${job.time.milliseconds}" @@ -248,8 +254,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { // JobScheduler has been stopped. 
} } finally { - ssc.sc.setLocalProperty(JobScheduler.BATCH_TIME_PROPERTY_KEY, null) - ssc.sc.setLocalProperty(JobScheduler.OUTPUT_OP_ID_PROPERTY_KEY, null) + ssc.sparkContext.setLocalProperties(oldProps) } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index a80154e2fc..806e181f61 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -182,7 +182,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(ssc.scheduler.isStarted === false) } - test("start should set job group and description of streaming jobs correctly") { + test("start should set local properties of streaming jobs correctly") { ssc = new StreamingContext(conf, batchDuration) ssc.sc.setJobGroup("non-streaming", "non-streaming", true) val sc = ssc.sc @@ -190,16 +190,22 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo @volatile var jobGroupFound: String = "" @volatile var jobDescFound: String = "" @volatile var jobInterruptFound: String = "" + @volatile var customPropFound: String = "" @volatile var allFound: Boolean = false addInputStream(ssc).foreachRDD { rdd => jobGroupFound = sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) jobDescFound = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) jobInterruptFound = sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) + customPropFound = sc.getLocalProperty("customPropKey") allFound = true } + ssc.sc.setLocalProperty("customPropKey", "value1") ssc.start() + // Local props set after start should be ignored + ssc.sc.setLocalProperty("customPropKey", "value2") + eventually(timeout(10 seconds), interval(10 milliseconds)) { assert(allFound === true) } @@ -208,11 +214,13 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo assert(jobGroupFound === null) assert(jobDescFound.contains("Streaming job from")) assert(jobInterruptFound === "false") + assert(customPropFound === "value1") // Verify current thread's thread-local properties have not changed assert(sc.getLocalProperty(SparkContext.SPARK_JOB_GROUP_ID) === "non-streaming") assert(sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) === "non-streaming") assert(sc.getLocalProperty(SparkContext.SPARK_JOB_INTERRUPT_ON_CANCEL) === "true") + assert(sc.getLocalProperty("customPropKey") === "value2") } test("start multiple times") { -- cgit v1.2.3 From 26d7af9119a73d851c86314b4a207c0bfe437082 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 11 Apr 2016 19:06:38 -0700 Subject: [SPARK-14520][SQL] Use correct return type in VectorizedParquetInputFormat ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-14520 `VectorizedParquetInputFormat` inherits `ParquetInputFormat` and overrides `createRecordReader`. However, its overridden `createRecordReader` returns a `ParquetRecordReader`. It should return a `RecordReader`. Otherwise, `ClassCastException` will be thrown. ## How was this patch tested? Existing tests. Author: Liang-Chi Hsieh Closes #12292 from viirya/fix-vectorized-input-format. 
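A minimal sketch of the failure mode described above, using toy stand-in classes rather than the real Hadoop/Parquet reader hierarchy: a cast to a sibling subclass compiles fine but fails at runtime, which is why the override has to declare the shared supertype `RecordReader` as its return type.

```scala
// Toy hierarchy; ColumnarReader and RowReader are illustrative stand-ins,
// not the actual VectorizedParquetRecordReader / ParquetRecordReader classes.
abstract class Reader
class ColumnarReader extends Reader
class RowReader extends Reader

object CastSketch {
  def main(args: Array[String]): Unit = {
    // Widening to the shared supertype is always safe.
    val ok: Reader = new ColumnarReader
    println(ok.getClass.getSimpleName)
    // Casting to a sibling subclass compiles, but throws ClassCastException at runtime.
    val boom = (new ColumnarReader).asInstanceOf[RowReader]
    println(boom)
  }
}
```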
--- .../spark/sql/execution/datasources/parquet/ParquetRelation.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index ca6803b737..bcb2b2de13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -476,8 +476,8 @@ private[sql] class DefaultSource final class VectorizedParquetInputFormat extends ParquetInputFormat[InternalRow] { override def createRecordReader( inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): ParquetRecordReader[InternalRow] = { - new VectorizedParquetRecordReader().asInstanceOf[ParquetRecordReader[InternalRow]] + taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = { + new VectorizedParquetRecordReader().asInstanceOf[RecordReader[Void, InternalRow]] } } -- cgit v1.2.3 From e9e1adc036643c6b126237903b8e79ab379b1d32 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 12 Apr 2016 03:24:26 +0100 Subject: [MINOR][ML] Fixed MLlib build warnings ## What changes were proposed in this pull request? Fixes to eliminate warnings during package and doc builds. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #12263 from jkbradley/warning-cleanups. --- .../org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java | 1 + .../test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala | 3 +++ 2 files changed, 4 insertions(+) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java index 86c389e11c..72bbb2a8fa 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java @@ -35,6 +35,7 @@ public class JavaStratifiedSamplingExample { SparkConf conf = new SparkConf().setAppName("JavaStratifiedSamplingExample"); JavaSparkContext jsc = new JavaSparkContext(conf); + @SuppressWarnings("unchecked") // $example on$ List> list = new ArrayList<>( Arrays.>asList( diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index e64551f03c..cd402b1e1f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -327,7 +327,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { case n: InternalNode => n.split match { case s: CategoricalSplit => assert(s.leftCategories === Array(1.0)) + case _ => throw new AssertionError("model.rootNode.split was not a CategoricalSplit") } + case _ => throw new AssertionError("model.rootNode was not an InternalNode") } } @@ -352,6 +354,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { assert(n.leftChild.isInstanceOf[InternalNode]) assert(n.rightChild.isInstanceOf[InternalNode]) Array(n.leftChild.asInstanceOf[InternalNode], n.rightChild.asInstanceOf[InternalNode]) + case _ => throw new AssertionError("rootNode was not an InternalNode") } // Single group second 
level tree construction. -- cgit v1.2.3 From 83fb96403bcfb1566e9d765690744824724737ac Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 11 Apr 2016 20:59:45 -0700 Subject: [SPARK-14132][SPARK-14133][SQL] Alter table partition DDLs ## What changes were proposed in this pull request? This implements a few alter table partition commands using the `SessionCatalog`. In particular: ``` ALTER TABLE ... ADD PARTITION ... ALTER TABLE ... DROP PARTITION ... ALTER TABLE ... RENAME PARTITION ... TO ... ``` The following operations are not supported, and an `AnalysisException` with a helpful error message will be thrown if the user tries to use them: ``` ALTER TABLE ... EXCHANGE PARTITION ... ALTER TABLE ... ARCHIVE PARTITION ... ALTER TABLE ... UNARCHIVE PARTITION ... ALTER TABLE ... TOUCH ... ALTER TABLE ... COMPACT ... ALTER TABLE ... CONCATENATE MSCK REPAIR TABLE ... ``` ## How was this patch tested? `DDLSuite`, `DDLCommandSuite` and `HiveDDLCommandSuite` Author: Andrew Or Closes #12220 from andrewor14/alter-partition-ddl. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 6 +- .../sql/catalyst/catalog/CatalogTestCases.scala | 18 ++- .../sql/catalyst/catalog/SessionCatalogSuite.scala | 28 +++- .../spark/sql/execution/SparkSqlParser.scala | 56 +++---- .../apache/spark/sql/execution/command/ddl.scala | 104 ++++++++----- .../sql/execution/command/DDLCommandSuite.scala | 118 +++++---------- .../spark/sql/execution/command/DDLSuite.scala | 162 ++++++++++++++++++++- .../hive/execution/HiveCompatibilitySuite.scala | 10 +- .../spark/sql/hive/HiveExternalCatalog.scala | 21 +-- .../apache/spark/sql/hive/client/HiveClient.scala | 9 +- .../spark/sql/hive/client/HiveClientImpl.scala | 22 ++- .../spark/sql/hive/HiveDDLCommandSuite.scala | 12 ++ 12 files changed, 360 insertions(+), 206 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 2f2e060b38..0e2cd39448 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -148,7 +148,7 @@ hiveNativeCommands | ROLLBACK WORK? | SHOW PARTITIONS tableIdentifier partitionSpec? | DFS .*? - | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | MSCK | LOAD) .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOAD) .*? 
; unsupportedHiveNativeCommands @@ -177,6 +177,7 @@ unsupportedHiveNativeCommands | kw1=UNLOCK kw2=DATABASE | kw1=CREATE kw2=TEMPORARY kw3=MACRO | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=MSCK kw2=REPAIR kw3=TABLE ; createTableHeader @@ -651,7 +652,7 @@ nonReserved | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION ; @@ -867,6 +868,7 @@ GRANT: 'GRANT'; LOCK: 'LOCK'; UNLOCK: 'UNLOCK'; MSCK: 'MSCK'; +REPAIR: 'REPAIR'; EXPORT: 'EXPORT'; IMPORT: 'IMPORT'; LOAD: 'LOAD'; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 0009438b31..0d9b0851fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -281,31 +281,37 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { test("drop partitions") { val catalog = newBasicCatalog() assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2))) - catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part2))) resetState() val catalog2 = newBasicCatalog() assert(catalogPartitionsEqual(catalog2, "db2", "tbl2", Seq(part1, part2))) - catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + catalog2.dropPartitions( + "db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) assert(catalog2.listPartitions("db2", "tbl2").isEmpty) } test("drop partitions when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) + catalog.dropPartitions( + "does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) } intercept[AnalysisException] { - catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "does_not_exist", Seq(), ignoreIfNotExists = false) } } test("drop partitions that do not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) } - catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) + catalog.dropPartitions( + "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) } test("get partition") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 862fc275ad..426273e1e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -496,19 +496,25 @@ class SessionCatalogSuite extends SparkFunSuite { val sessionCatalog = new SessionCatalog(externalCatalog) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec), + ignoreIfNotExists = false) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2))) // Drop partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") sessionCatalog.dropPartitions( - TableIdentifier("tbl2"), Seq(part2.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2"), + Seq(part2.spec), + ignoreIfNotExists = false) assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) // Drop multiple partitions at once sessionCatalog.createPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec, part2.spec), + ignoreIfNotExists = false) assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) } @@ -516,11 +522,15 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("tbl1", Some("does_not_exist")), Seq(), ignoreIfNotExists = false) + TableIdentifier("tbl1", Some("does_not_exist")), + Seq(), + ignoreIfNotExists = false) } intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("does_not_exist", Some("db2")), Seq(), ignoreIfNotExists = false) + TableIdentifier("does_not_exist", Some("db2")), + Seq(), + ignoreIfNotExists = false) } } @@ -528,10 +538,14 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = false) } catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = true) + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = true) } test("get partition") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3da715cdb3..73d9640c35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -493,7 +493,9 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitAddTablePartition( ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { - if (ctx.VIEW != null) throw new ParseException(s"Operation not allowed: partitioned views", ctx) + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } // Create partition spec to location mapping. 
val specsAndLocs = if (ctx.partitionSpec.isEmpty) { ctx.partitionSpecLocation.asScala.map { @@ -509,8 +511,7 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableAddPartition( visitTableIdentifier(ctx.tableIdentifier), specsAndLocs, - ctx.EXISTS != null)( - command(ctx)) + ctx.EXISTS != null) } /** @@ -523,11 +524,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitExchangeTablePartition( ctx: ExchangeTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableExchangePartition( - visitTableIdentifier(ctx.from), - visitTableIdentifier(ctx.to), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... EXCHANGE PARTITION ...") } /** @@ -543,8 +541,7 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableRenamePartition( visitTableIdentifier(ctx.tableIdentifier), visitNonOptionalPartitionSpec(ctx.from), - visitNonOptionalPartitionSpec(ctx.to))( - command(ctx)) + visitNonOptionalPartitionSpec(ctx.to)) } /** @@ -561,13 +558,16 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitDropTablePartitions( ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { - if (ctx.VIEW != null) throw new ParseException(s"Operation not allowed: partitioned views", ctx) + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } + if (ctx.PURGE != null) { + throw new AnalysisException(s"Operation not allowed: PURGE") + } AlterTableDropPartition( visitTableIdentifier(ctx.tableIdentifier), ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), - ctx.EXISTS != null, - ctx.PURGE != null)( - command(ctx)) + ctx.EXISTS != null) } /** @@ -580,10 +580,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitArchiveTablePartition( ctx: ArchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableArchivePartition( - visitTableIdentifier(ctx.tableIdentifier), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... ARCHIVE PARTITION ...") } /** @@ -596,10 +594,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitUnarchiveTablePartition( ctx: UnarchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableUnarchivePartition( - visitTableIdentifier(ctx.tableIdentifier), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... UNARCHIVE PARTITION ...") } /** @@ -658,10 +654,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitTouchTable(ctx: TouchTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableTouch( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... TOUCH ...") } /** @@ -673,11 +666,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitCompactTable(ctx: CompactTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableCompact( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - string(ctx.STRING))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... 
COMPACT ...") } /** @@ -689,10 +678,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitConcatenateTable(ctx: ConcatenateTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableMerge( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... CONCATENATE") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 8a37cf8f4c..c55b1a690e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ @@ -348,53 +349,94 @@ case class AlterTableSerDeProperties( } /** - * Add Partition in ALTER TABLE/VIEW: add the table/view partitions. + * Add Partition in ALTER TABLE: add the table partitions. + * * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE, * EXCEPT that it is ILLEGAL to specify a LOCATION clause. * An error message will be issued if the partition exists, unless 'ifNotExists' is true. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] + * }}} */ case class AlterTableAddPartition( tableName: TableIdentifier, partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])], - ifNotExists: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + ifNotExists: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table add partition is not allowed for tables defined using the datasource API") + } + val parts = partitionSpecsAndLocs.map { case (spec, location) => + // inherit table storage format (possibly except for location) + CatalogTablePartition(spec, table.storage.copy(locationUri = location)) + } + catalog.createPartitions(tableName, parts, ignoreIfExists = ifNotExists) + Seq.empty[Row] + } +} + +/** + * Alter a table partition's spec. 
+ * + * The syntax of this command is: + * {{{ + * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2; + * }}} + */ case class AlterTableRenamePartition( tableName: TableIdentifier, oldPartition: TablePartitionSpec, - newPartition: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + newPartition: TablePartitionSpec) + extends RunnableCommand { -case class AlterTableExchangePartition( - fromTableName: TableIdentifier, - toTableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.renamePartitions( + tableName, Seq(oldPartition), Seq(newPartition)) + Seq.empty[Row] + } + +} /** - * Drop Partition in ALTER TABLE/VIEW: to drop a particular partition for a table/view. + * Drop Partition in ALTER TABLE: to drop a particular partition for a table. + * * This removes the data and metadata for this partition. * The data is actually moved to the .Trash/Current directory if Trash is configured, * unless 'purge' is true, but the metadata is completely lost. * An error message will be issued if the partition does not exist, unless 'ifExists' is true. * Note: purge is always false when the target is a view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; + * }}} */ case class AlterTableDropPartition( tableName: TableIdentifier, specs: Seq[TablePartitionSpec], - ifExists: Boolean, - purge: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + ifExists: Boolean) + extends RunnableCommand { -case class AlterTableArchivePartition( - tableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table drop partition is not allowed for tables defined using the datasource API") + } + catalog.dropPartitions(tableName, specs, ignoreIfNotExists = ifExists) + Seq.empty[Row] + } -case class AlterTableUnarchivePartition( - tableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging +} case class AlterTableSetFileFormat( tableName: TableIdentifier, @@ -453,22 +495,6 @@ case class AlterTableSetLocation( } -case class AlterTableTouch( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec])(sql: String) - extends NativeDDLCommand(sql) with Logging - -case class AlterTableCompact( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec], - compactType: String)(sql: String) - extends NativeDDLCommand(sql) with Logging - -case class AlterTableMerge( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec])(sql: String) - extends NativeDDLCommand(sql) with Logging - case class AlterTableChangeCol( tableName: TableIdentifier, partitionSpec: Option[TablePartitionSpec], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index ac69518ddf..1c8dd68286 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.command +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest @@ -24,9 +25,17 @@ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.types._ +// TODO: merge this with DDLSuite (SPARK-14441) class DDLCommandSuite extends PlanTest { private val parser = SparkSqlParser + private def assertUnsupported(sql: String): Unit = { + val e = intercept[AnalysisException] { + parser.parsePlan(sql) + } + assert(e.getMessage.toLowerCase.contains("operation not allowed")) + } + test("create database") { val sql = """ @@ -326,11 +335,11 @@ class DDLCommandSuite extends PlanTest { Seq( (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), - ifNotExists = true)(sql1) + ifNotExists = true) val expected2 = AlterTableAddPartition( TableIdentifier("table_name", None), Seq((Map("dt" -> "2008-08-08"), Some("loc"))), - ifNotExists = false)(sql2) + ifNotExists = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -369,22 +378,16 @@ class DDLCommandSuite extends PlanTest { val expected = AlterTableRenamePartition( TableIdentifier("table_name", None), Map("dt" -> "2008-08-08", "country" -> "us"), - Map("dt" -> "2008-09-09", "country" -> "uk"))(sql) + Map("dt" -> "2008-09-09", "country" -> "uk")) comparePlans(parsed, expected) } - test("alter table: exchange partition") { - val sql = + test("alter table: exchange partition (not supported)") { + assertUnsupported( """ |ALTER TABLE table_name_1 EXCHANGE PARTITION |(dt='2008-08-08', country='us') WITH TABLE table_name_2 - """.stripMargin - val parsed = parser.parsePlan(sql) - val expected = AlterTableExchangePartition( - TableIdentifier("table_name_1", None), - TableIdentifier("table_name_2", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + """.stripMargin) } // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] 
[PURGE] @@ -405,7 +408,10 @@ class DDLCommandSuite extends PlanTest { val sql2_view = sql2_table.replace("TABLE", "VIEW").replace("PURGE", "") val parsed1_table = parser.parsePlan(sql1_table) - val parsed2_table = parser.parsePlan(sql2_table) + val e = intercept[ParseException] { + parser.parsePlan(sql2_table) + } + assert(e.getMessage.contains("Operation not allowed")) intercept[ParseException] { parser.parsePlan(sql1_view) @@ -420,36 +426,17 @@ class DDLCommandSuite extends PlanTest { Seq( Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2009-09-09", "country" -> "uk")), - ifExists = true, - purge = false)(sql1_table) - val expected2_table = AlterTableDropPartition( - tableIdent, - Seq( - Map("dt" -> "2008-08-08", "country" -> "us"), - Map("dt" -> "2009-09-09", "country" -> "uk")), - ifExists = false, - purge = true)(sql2_table) + ifExists = true) comparePlans(parsed1_table, expected1_table) - comparePlans(parsed2_table, expected2_table) } - test("alter table: archive partition") { - val sql = "ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')" - val parsed = parser.parsePlan(sql) - val expected = AlterTableArchivePartition( - TableIdentifier("table_name", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + test("alter table: archive partition (not supported)") { + assertUnsupported("ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')") } - test("alter table: unarchive partition") { - val sql = "ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')" - val parsed = parser.parsePlan(sql) - val expected = AlterTableUnarchivePartition( - TableIdentifier("table_name", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + test("alter table: unarchive partition (not supported)") { + assertUnsupported("ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')") } test("alter table: set file format") { @@ -505,55 +492,24 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } - test("alter table: touch") { - val sql1 = "ALTER TABLE table_name TOUCH" - val sql2 = "ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableTouch( - tableIdent, - None)(sql1) - val expected2 = AlterTableTouch( - tableIdent, - Some(Map("dt" -> "2008-08-08", "country" -> "us")))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + test("alter table: touch (not supported)") { + assertUnsupported("ALTER TABLE table_name TOUCH") + assertUnsupported("ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')") } - test("alter table: compact") { - val sql1 = "ALTER TABLE table_name COMPACT 'compaction_type'" - val sql2 = + test("alter table: compact (not supported)") { + assertUnsupported("ALTER TABLE table_name COMPACT 'compaction_type'") + assertUnsupported( """ - |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') - |COMPACT 'MAJOR' - """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableCompact( - tableIdent, - None, - "compaction_type")(sql1) - val expected2 = AlterTableCompact( - tableIdent, - Some(Map("dt" -> "2008-08-08", "country" -> "us")), - "MAJOR")(sql2) - 
comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') + |COMPACT 'MAJOR' + """.stripMargin) } - test("alter table: concatenate") { - val sql1 = "ALTER TABLE table_name CONCATENATE" - val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableMerge(tableIdent, None)(sql1) - val expected2 = AlterTableMerge( - tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + test("alter table: concatenate (not supported)") { + assertUnsupported("ALTER TABLE table_name CONCATENATE") + assertUnsupported( + "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE") } test("alter table: change column name/type/position/comment") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index c6479bf33e..40a8b0e614 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -58,6 +58,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e.getMessage.toLowerCase.contains("operation not allowed")) } + private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { + if (expectException) intercept[AnalysisException] { body } else body + } + private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase(CatalogDatabase(name, "", "", Map()), ignoreIfExists = false) } @@ -320,6 +324,62 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: add partition is not supported for views") { + assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition is not supported for views") { + assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") + } + + test("alter table: rename partition") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3)) + sql("ALTER TABLE dbx.tab1 PARTITION (a='1') RENAME TO PARTITION (a='100')") + sql("ALTER TABLE dbx.tab1 PARTITION (b='2') RENAME TO PARTITION (b='200')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "100"), Map("b" 
-> "200"), part3)) + // rename without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 PARTITION (a='100') RENAME TO PARTITION (a='10')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "10"), Map("b" -> "200"), part3)) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") + } + // partition to rename does not exist + intercept[AnalysisException] { + sql("ALTER TABLE tab1 PARTITION (x='300') RENAME TO PARTITION (x='333')") + } + } + // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext test("show tables") { @@ -487,10 +547,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(storageFormat.locationUri === Some(expected)) } } - // Optionally expect AnalysisException - def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { - if (expectException) intercept[AnalysisException] { body } else body - } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") verifyLocation("/path/to/your/lovely/heart") @@ -564,4 +620,102 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + private def testAddPartitions(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + val part4 = Map("d" -> "4") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + + "PARTITION (b='2') LOCATION 'paris' PARTITION (c='3')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Some("paris")) + assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + } + // add partitions without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (d='4')") + } + // partition to add already exists + intercept[AnalysisException] { + sql("ALTER TABLE tab1 ADD PARTITION (d='4')") + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + } + } + + private def testDropPartitions(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + val part4 = Map("d" -> "4") + 
createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + createTablePartition(catalog, part4, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (d='4'), PARTITION (c='3')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) + } + // drop partitions without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (b='2')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (b='2')") + } + // partition to drop does not exist + intercept[AnalysisException] { + sql("ALTER TABLE tab1 DROP PARTITION (x='300')") + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (x='300')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + } + } + } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 9e3cb18d45..f0eeda09db 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -376,7 +376,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Create partitioned view is not supported "create_like_view", - "describe_formatted_view_partitioned" + "describe_formatted_view_partitioned", + + // This uses CONCATENATE, which we don't support + "alter_merge_2", + + // TOUCH is not supported + "touch" ) /** @@ -392,7 +398,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter2", "alter3", "alter5", - "alter_merge_2", "alter_partition_format_loc", "alter_partition_with_whitelist", "alter_rename_partition", @@ -897,7 +902,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_comparison", "timestamp_lazy", "timestamp_null", - "touch", "transform_ppr1", "transform_ppr2", "truncate_table", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index a49ce33ba1..482f47428d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -219,26 +219,7 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat parts: Seq[TablePartitionSpec], ignoreIfNotExists: Boolean): Unit = withClient { requireTableExists(db, table) - // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we - // need to implement it here ourselves. 
This is currently somewhat expensive because - // we make multiple synchronous calls to Hive for each partition we want to drop. - val partsToDrop = - if (ignoreIfNotExists) { - parts.filter { spec => - try { - getPartition(db, table, spec) - true - } catch { - // Filter out the partitions that do not actually exist - case _: AnalysisException => false - } - } - } else { - parts - } - if (partsToDrop.nonEmpty) { - client.dropPartitions(db, table, partsToDrop) - } + client.dropPartitions(db, table, parts, ignoreIfNotExists) } override def renamePartitions( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 94794b1572..6f7e7bf451 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -120,16 +120,13 @@ private[hive] trait HiveClient { ignoreIfExists: Boolean): Unit /** - * Drop one or many partitions in the given table. - * - * Note: Unfortunately, Hive does not currently provide a way to ignore this call if the - * partitions do not already exist. The seemingly relevant flag `ifExists` in - * [[org.apache.hadoop.hive.metastore.PartitionDropOptions]] is not read anywhere. + * Drop one or many partitions in the given table, assuming they exist. */ def dropPartitions( db: String, table: String, - specs: Seq[ExternalCatalog.TablePartitionSpec]): Unit + specs: Seq[ExternalCatalog.TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit /** * Rename one or many existing table partitions, assuming they exist. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a037671ef0..39e26acd7f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.cli.CliSessionState import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} +import org.apache.hadoop.hive.metastore.{PartitionDropOptions, TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} @@ -367,9 +367,25 @@ private[hive] class HiveClientImpl( override def dropPartitions( db: String, table: String, - specs: Seq[ExternalCatalog.TablePartitionSpec]): Unit = withHiveState { + specs: Seq[ExternalCatalog.TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = withHiveState { // TODO: figure out how to drop multiple partitions in one call - specs.foreach { s => client.dropPartition(db, table, s.values.toList.asJava, true) } + val hiveTable = client.getTable(db, table, true /* throw exception */) + specs.foreach { s => + // The provided spec here can be a partial spec, i.e. it will match all partitions + // whose specs are supersets of this partial spec. E.g. If a table has partitions + // (b='1', c='1') and (b='1', c='2'), a partial spec of (b='1') will match both. 
+ val matchingParts = client.getPartitions(hiveTable, s.asJava).asScala + if (matchingParts.isEmpty && !ignoreIfNotExists) { + throw new AnalysisException( + s"partition to drop '$s' does not exist in table '$table' database '$db'") + } + matchingParts.foreach { hivePartition => + val dropOptions = new PartitionDropOptions + dropOptions.ifExists = ignoreIfNotExists + client.dropPartition(db, table, hivePartition.getValues, dropOptions) + } + } } override def renamePartitions( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index a144da4997..e8086aec32 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -41,6 +41,13 @@ class HiveDDLCommandSuite extends PlanTest { }.head } + private def assertUnsupported(sql: String): Unit = { + val e = intercept[ParseException] { + parser.parsePlan(sql) + } + assert(e.getMessage.toLowerCase.contains("unsupported")) + } + test("Test CTAS #1") { val s1 = """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view @@ -367,4 +374,9 @@ class HiveDDLCommandSuite extends PlanTest { parser.parsePlan(v1).isInstanceOf[HiveNativeCommand] } } + + test("MSCK repair table (not supported)") { + assertUnsupported("MSCK REPAIR TABLE tab1") + } + } -- cgit v1.2.3 From 2d81ba542e12db65c2bd67357093244be9403102 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 11 Apr 2016 22:33:05 -0700 Subject: [SPARK-14362][SPARK-14406][SQL][FOLLOW-UP] DDL Native Support: Drop View and Drop Table #### What changes were proposed in this pull request? In this PR, we are trying to address the comment in the original PR: https://github.com/apache/spark/commit/dfce9665c4b2b29a19e6302216dae2800da68ff9#commitcomment-17057030 In this PR, we checks if table/view exists at the beginning and then does not need to capture the exceptions, including `NoSuchTableException` and `InvalidTableException`. We still capture the NonFatal exception when doing `sqlContext.cacheManager.tryUncacheQuery`. #### How was this patch tested? The existing test cases should cover the code changes of this PR. Author: gatorsmile Closes #12321 from gatorsmile/dropViewFollowup. --- .../apache/spark/sql/execution/command/ddl.scala | 50 +++++++++++----------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index c55b1a690e..758a7e45d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import scala.util.control.NonFatal + import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier @@ -192,31 +194,31 @@ case class DropTable( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view - // issue an exception. - catalog.getTableMetadataOption(tableName).map(_.tableType match { - case CatalogTableType.VIRTUAL_VIEW if !isView => - throw new AnalysisException( - "Cannot drop a view with DROP TABLE. 
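A hedged sketch of the guard this patch extends, assuming a simplified schema type and field counter (stand-ins for Spark's `numOfNestedFields` and `wholeStageMaxNumFields`, not the actual implementation): the limit is applied to the plan's output schema and to each child's schema, so a plan with a very wide input also falls back to the non-codegen path.

```scala
// Simplified stand-ins for a nested schema and the field-count limit check.
sealed trait FieldType
case object Atomic extends FieldType
case class Struct(fields: Seq[FieldType]) extends FieldType

object CodegenGuardSketch {
  def numOfNestedFields(t: FieldType): Int = t match {
    case Atomic     => 1
    case Struct(fs) => fs.map(numOfNestedFields).sum
  }

  def main(args: Array[String]): Unit = {
    val maxFields  = 200                              // plays the role of wholeStageMaxNumFields
    val outputSide = Struct(Seq(Atomic))              // narrow output: passes the old check alone
    val inputSide  = Struct(Seq.fill(1000)(Atomic))   // very wide input: the case this patch catches
    val supported  = numOfNestedFields(outputSide) <= maxFields &&
      numOfNestedFields(inputSide) <= maxFields
    println(s"whole-stage codegen enabled: $supported") // false -> fall back to regular execution
  }
}
```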
Please use DROP VIEW instead") - case o if o != CatalogTableType.VIRTUAL_VIEW && isView => - throw new AnalysisException( - s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") - case _ => - }) - - try { - sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString)) - } catch { - // This table's metadata is not in Hive metastore (e.g. the table does not exist). - case e if e.getClass.getName == "org.apache.hadoop.hive.ql.metadata.InvalidTableException" => - case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => - // Other Throwables can be caused by users providing wrong parameters in OPTIONS - // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an error. - case e: Throwable => log.warn(s"${e.getMessage}", e) + if (!catalog.tableExists(tableName)) { + if (!ifExists) { + val objectName = if (isView) "View" else "Table" + logError(s"$objectName '${tableName.quotedString}' does not exist") + } + } else { + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadataOption(tableName).map(_.tableType match { + case CatalogTableType.VIRTUAL_VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIRTUAL_VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") + case _ => + }) + try { + sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString)) + } catch { + case NonFatal(e) => log.warn(s"${e.getMessage}", e) + } + catalog.invalidateTable(tableName) + catalog.dropTable(tableName, ifExists) } - catalog.invalidateTable(tableName) - catalog.dropTable(tableName, ifExists) Seq.empty[Row] } } -- cgit v1.2.3 From 52a801124f429ab133f9a3867c1da6ebd8fa7d4e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 11 Apr 2016 22:58:35 -0700 Subject: [SPARK-14554][SQL] disable whole stage codegen if there are too many input columns ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/12047/files#diff-94a1f59bcc9b6758c4ca874652437634R529, we may split field expression codes in `CreateExternalRow` to support wide tables. However, the whole stage codegen framework doesn't support it, because the input for expressions is not always the input row, but can be `CodeGenContext.currentVars`, which doesn't work well with `CodeGenContext.splitExpressions`. Actually we do have a check to guard against these cases, but it is incomplete: it only checks output fields. This PR improves the whole stage codegen support check, to disable it if there are too many input fields, so that we can avoid splitting field expression codes in `CreateExternalRow` for whole stage codegen. TODO: Is it a better solution if we can make `CodeGenContext.currentVars` work well with `CodeGenContext.splitExpressions`? ## How was this patch tested? New test in DatasetSuite. Author: Wenchen Fan Closes #12322 from cloud-fan/codegen.
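As a rough, self-contained sketch of the improved check described above (illustrative object and method names; the real logic lives in `CollapseCodegenStages`, shown in the diff that follows, and may differ in detail), codegen stays enabled only when neither the plan's output schema nor any child's schema has too many nested fields:

```scala
import org.apache.spark.sql.types._

object WideSchemaCheck {
  // Count leaf fields recursively so nested structs are fully accounted for.
  // Approximates the numOfNestedFields helper referenced by this patch.
  def numOfNestedFields(dataType: DataType): Int = dataType match {
    case StructType(fields) => fields.map(f => numOfNestedFields(f.dataType)).sum
    case _ => 1
  }

  // Whole-stage codegen is allowed only if the output schema and every
  // child (input) schema stay below the configured field limit.
  def withinFieldLimit(output: StructType, inputs: Seq[StructType], maxFields: Int): Boolean =
    numOfNestedFields(output) <= maxFields &&
      inputs.forall(s => numOfNestedFields(s) <= maxFields)
}
```

Under this rule, a 1000-column projection like the one exercised by the new DatasetSuite test exceeds the configured `wholeStageMaxNumFields` limit on the input side of the following `map`, so that stage falls back to the non-codegen path instead of generating unsplittable code.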
--- .../scala/org/apache/spark/sql/execution/WholeStageCodegen.scala | 7 +++++-- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index c4594f0480..447dbe7018 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -446,8 +446,11 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case plan: CodegenSupport if plan.supportCodegen => val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns - val haveTooManyFields = numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields - !willFallback && !haveTooManyFields + val hasTooManyOutputFields = + numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields + val hasTooManyInputFields = + plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields) + !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields case _ => false } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index e8e801084f..47251681e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -620,6 +620,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val df = streaming.join(static, Seq("b")) assert(df.isStreaming, "streaming Dataset returned false for 'isStreaming'.") } + + test("SPARK-14554: Dataset.map may generate wrong java code for wide table") { + val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + // Make sure the generated code for this plan can compile and execute. + wideDF.map(_.getLong(0)).collect() + } } case class OtherTuple(_1: String, _2: Int) -- cgit v1.2.3 From 678b96e77bf77a64b8df14b19db5a3bb18febfe3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 11 Apr 2016 22:59:42 -0700 Subject: [SPARK-14535][SQL] Remove buildInternalScan from FileFormat ## What changes were proposed in this pull request? Now `HadoopFsRelation` with all kinds of file formats can be handled in `FileSourceStrategy`, we can remove the branches for `HadoopFsRelation` in `FileSourceStrategy` and the `buildInternalScan` API from `FileFormat`. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #12300 from cloud-fan/remove. 
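To make the shape of the surviving API concrete, here is a minimal, hypothetical sketch (simplified stand-in types and names, not Spark's real `FileFormat`, `PartitionedFile`, or `buildReader` signature): once `buildInternalScan` is gone, a format only has to turn one file split into an iterator of rows, while the planner handles file listing, partition values, and bucketing.

```scala
// Hypothetical stand-ins for illustration only.
final case class FileSplit(path: String, start: Long, length: Long)

trait MiniFileFormat[T] {
  // Returns a function that reads a single split; no whole-table scan method is needed.
  def buildReader(
      requiredColumns: Seq[String],
      options: Map[String, String]): FileSplit => Iterator[T]
}

// A trivial "format" that emits one row per split, just to exercise the contract.
object PathEchoFormat extends MiniFileFormat[String] {
  def buildReader(
      requiredColumns: Seq[String],
      options: Map[String, String]): FileSplit => Iterator[String] =
    split => Iterator(s"${split.path}:${split.start}+${split.length}")
}

object MiniFileFormatDemo extends App {
  val reader = PathEchoFormat.buildReader(Seq("value"), Map.empty)
  Seq(FileSplit("/tmp/part-0", 0L, 128L), FileSplit("/tmp/part-1", 0L, 64L))
    .iterator.flatMap(reader).foreach(println)
}
```

The deletions in the diffs below simply drop the old per-format `buildInternalScan` implementations; planning for every format now funnels through `FileSourceStrategy` and these per-file readers.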
--- .../spark/ml/source/libsvm/LibSVMRelation.scala | 34 +- .../execution/datasources/DataSourceStrategy.scala | 390 --------------------- .../execution/datasources/FileSourceStrategy.scala | 10 +- .../execution/datasources/csv/DefaultSource.scala | 31 -- .../execution/datasources/json/JSONRelation.scala | 29 -- .../datasources/parquet/ParquetRelation.scala | 110 +----- .../execution/datasources/text/DefaultSource.scala | 39 --- .../org/apache/spark/sql/internal/SQLConf.scala | 8 - .../org/apache/spark/sql/sources/interfaces.scala | 10 - .../datasources/FileSourceStrategySuite.scala | 12 - .../apache/spark/sql/hive/orc/OrcRelation.scala | 13 - 11 files changed, 5 insertions(+), 681 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 2e9b6be9a2..4737b6fe52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -178,39 +178,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - // TODO: This does not handle cases where column pruning has been performed. - - verifySchema(dataSchema) - val dataFiles = inputFiles.filterNot(_.getPath.getName startsWith "_") - - val path = if (dataFiles.length == 1) dataFiles.head.getPath.toUri.toString - else if (dataFiles.isEmpty) throw new IOException("No input path specified for libsvm data") - else throw new IOException("Multiple input paths are not supported for libsvm data.") - - val numFeatures = options.getOrElse("numFeatures", "-1").toInt - val vectorType = options.getOrElse("vectorType", "sparse") - - val sc = sqlContext.sparkContext - val baseRdd = MLUtils.loadLibSVMFile(sc, path, numFeatures) - val sparse = vectorType == "sparse" - baseRdd.map { pt => - val features = if (sparse) pt.features.toSparse else pt.features.toDense - Row(pt.label, features) - }.mapPartitions { externalRows => - val converter = RowEncoder(dataSchema) - externalRows.map(converter.toRow) - } - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, @@ -218,6 +185,7 @@ class DefaultSource extends FileFormat with DataSourceRegister { requiredSchema: StructType, filters: Seq[Filter], options: Map[String, String]): (PartitionedFile) => Iterator[InternalRow] = { + verifySchema(dataSchema) val numFeatures = options("numFeatures").toInt assert(numFeatures > 0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 8c183317f6..c3885a3be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -110,133 +110,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { filters, (a, _) => toCatalystRDD(l, a, t.buildScan(a.map(_.name).toArray))) :: Nil - // Scanning partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, 
_, _)) - if t.partitionSchema.nonEmpty => - // We divide the filter expressions into 3 parts - val partitionColumns = AttributeSet( - t.partitionSchema.map(c => l.output.find(_.name == c.name).get)) - - // Only pruning the partition keys - val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns)) - - // Only pushes down predicates that do not reference partition keys. - val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) - - // Predicates with both partition keys and attributes - val partitionAndNormalColumnFilters = - filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet - - val selectedPartitions = t.location.listFiles(partitionFilters) - - logInfo { - val total = t.partitionSpec.partitions.length - val selected = selectedPartitions.length - val percentPruned = (1 - selected.toDouble / total.toDouble) * 100 - s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." - } - - // need to add projections from "partitionAndNormalColumnAttrs" in if it is not empty - val partitionAndNormalColumnAttrs = AttributeSet(partitionAndNormalColumnFilters) - val partitionAndNormalColumnProjs = if (partitionAndNormalColumnAttrs.isEmpty) { - projects - } else { - (partitionAndNormalColumnAttrs ++ projects).toSeq - } - - // Prune the buckets based on the pushed filters that do not contain partitioning key - // since the bucketing key is not allowed to use the columns in partitioning key - val bucketSet = getBuckets(pushedFilters, t.bucketSpec) - val scan = buildPartitionedTableScan( - l, - partitionAndNormalColumnProjs, - pushedFilters, - bucketSet, - t.partitionSpec.partitionColumns, - selectedPartitions, - t.options) - - // Add a Projection to guarantee the original projection: - // this is because "partitionAndNormalColumnAttrs" may be different - // from the original "projects", in elements or their ordering - - partitionAndNormalColumnFilters.reduceLeftOption(expressions.And).map(cf => - if (projects.isEmpty || projects == partitionAndNormalColumnProjs) { - // if the original projection is empty, no need for the additional Project either - execution.Filter(cf, scan) - } else { - execution.Project(projects, execution.Filter(cf, scan)) - } - ).getOrElse(scan) :: Nil - - // TODO: The code for planning bucketed/unbucketed/partitioned/unpartitioned tables contains - // a lot of duplication and produces overly complicated RDDs. - - // Scanning non-partitioned HadoopFsRelation - case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _, _)) => - // See buildPartitionedTableScan for the reason that we need to create a shard - // broadcast HadoopConf. 
- val sharedHadoopConf = SparkHadoopUtil.get.conf - val confBroadcast = - t.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) - - t.bucketSpec match { - case Some(spec) if t.sqlContext.conf.bucketingEnabled => - val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = { - (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { - val bucketed = - t.location - .allFiles() - .filterNot(_.getPath.getName startsWith "_") - .groupBy { f => - BucketingUtils - .getBucketId(f.getPath.getName) - .getOrElse(sys.error(s"Invalid bucket file ${f.getPath}")) - } - - val bucketedDataMap = bucketed.mapValues { bucketFiles => - t.fileFormat.buildInternalScan( - t.sqlContext, - t.dataSchema, - requiredColumns.map(_.name).toArray, - filters, - None, - bucketFiles, - confBroadcast, - t.options).coalesce(1) - } - - val bucketedRDD = new UnionRDD(t.sqlContext.sparkContext, - (0 until spec.numBuckets).map { bucketId => - bucketedDataMap.getOrElse(bucketId, t.sqlContext.emptyResult: RDD[InternalRow]) - }) - bucketedRDD - } - } - - pruneFilterProject( - l, - projects, - filters, - scanBuilder) :: Nil - - case _ => - pruneFilterProject( - l, - projects, - filters, - (a, f) => - t.fileFormat.buildInternalScan( - t.sqlContext, - t.dataSchema, - a.map(_.name).toArray, - f, - None, - t.location.allFiles(), - confBroadcast, - t.options)) :: Nil - } - case l @ LogicalRelation(baseRelation: TableScan, _, _) => execution.DataSourceScan.create( l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil @@ -248,218 +121,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { case _ => Nil } - private def buildPartitionedTableScan( - logicalRelation: LogicalRelation, - projections: Seq[NamedExpression], - filters: Seq[Expression], - buckets: Option[BitSet], - partitionColumns: StructType, - partitions: Seq[Partition], - options: Map[String, String]): SparkPlan = { - val relation = logicalRelation.relation.asInstanceOf[HadoopFsRelation] - - // Because we are creating one RDD per partition, we need to have a shared HadoopConf. - // Otherwise, the cost of broadcasting HadoopConf in every RDD will be high. - val sharedHadoopConf = SparkHadoopUtil.get.conf - val confBroadcast = - relation.sqlContext.sparkContext.broadcast(new SerializableConfiguration(sharedHadoopConf)) - val partitionColumnNames = partitionColumns.fieldNames.toSet - - // Now, we create a scan builder, which will be used by pruneFilterProject. This scan builder - // will union all partitions and attach partition values if needed. - val scanBuilder: (Seq[Attribute], Array[Filter]) => RDD[InternalRow] = { - (requiredColumns: Seq[Attribute], filters: Array[Filter]) => { - - relation.bucketSpec match { - case Some(spec) if relation.sqlContext.conf.bucketingEnabled => - val requiredDataColumns = - requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) - - // Builds RDD[Row]s for each selected partition. - val perPartitionRows: Seq[(Int, RDD[InternalRow])] = partitions.flatMap { - case Partition(partitionValues, files) => - val bucketed = files.groupBy { f => - BucketingUtils - .getBucketId(f.getPath.getName) - .getOrElse(sys.error(s"Invalid bucket file ${f.getPath}")) - } - - bucketed.map { bucketFiles => - // Don't scan any partition columns to save I/O. Here we are being optimistic and - // assuming partition columns data stored in data files are always consistent with - // those partition values encoded in partition directory paths. 
- val dataRows = relation.fileFormat.buildInternalScan( - relation.sqlContext, - relation.dataSchema, - requiredDataColumns.map(_.name).toArray, - filters, - buckets, - bucketFiles._2, - confBroadcast, - options) - - // Merges data values with partition values. - bucketFiles._1 -> mergeWithPartitionValues( - requiredColumns, - requiredDataColumns, - partitionColumns, - partitionValues, - dataRows) - } - } - - val bucketedDataMap: Map[Int, Seq[RDD[InternalRow]]] = - perPartitionRows.groupBy(_._1).mapValues(_.map(_._2)) - - val bucketed = new UnionRDD(relation.sqlContext.sparkContext, - (0 until spec.numBuckets).map { bucketId => - bucketedDataMap.get(bucketId).map(i => i.reduce(_ ++ _).coalesce(1)).getOrElse { - relation.sqlContext.emptyResult: RDD[InternalRow] - } - }) - bucketed - - case _ => - val requiredDataColumns = - requiredColumns.filterNot(c => partitionColumnNames.contains(c.name)) - - // Builds RDD[Row]s for each selected partition. - val perPartitionRows = partitions.map { - case Partition(partitionValues, files) => - val dataRows = relation.fileFormat.buildInternalScan( - relation.sqlContext, - relation.dataSchema, - requiredDataColumns.map(_.name).toArray, - filters, - buckets, - files, - confBroadcast, - options) - - // Merges data values with partition values. - mergeWithPartitionValues( - requiredColumns, - requiredDataColumns, - partitionColumns, - partitionValues, - dataRows) - } - new UnionRDD(relation.sqlContext.sparkContext, perPartitionRows) - } - } - } - - // Create the scan operator. If needed, add Filter and/or Project on top of the scan. - // The added Filter/Project is on top of the unioned RDD. We do not want to create - // one Filter/Project for every partition. - val sparkPlan = pruneFilterProject( - logicalRelation, - projections, - filters, - scanBuilder) - - sparkPlan - } - - /** - * Creates a ColumnarBatch that contains the values for `requiredColumns`. These columns can - * either come from `input` (columns scanned from the data source) or from the partitioning - * values (data from `partitionValues`). This is done *once* per physical partition. When - * the column is from `input`, it just references the same underlying column. When using - * partition columns, the column is populated once. - * TODO: there's probably a cleaner way to do this. 
- */ - private def projectedColumnBatch( - input: ColumnarBatch, - requiredColumns: Seq[Attribute], - dataColumns: Seq[Attribute], - partitionColumnSchema: StructType, - partitionValues: InternalRow) : ColumnarBatch = { - val result = ColumnarBatch.allocate(StructType.fromAttributes(requiredColumns)) - var resultIdx = 0 - var inputIdx = 0 - - while (resultIdx < requiredColumns.length) { - val attr = requiredColumns(resultIdx) - if (inputIdx < dataColumns.length && requiredColumns(resultIdx) == dataColumns(inputIdx)) { - result.setColumn(resultIdx, input.column(inputIdx)) - inputIdx += 1 - } else { - require(partitionColumnSchema.fields.count(_.name == attr.name) == 1) - var partitionIdx = 0 - partitionColumnSchema.fields.foreach { f => { - if (f.name.equals(attr.name)) { - ColumnVectorUtils.populate(result.column(resultIdx), partitionValues, partitionIdx) - } - partitionIdx += 1 - }} - } - resultIdx += 1 - } - result - } - - private def mergeWithPartitionValues( - requiredColumns: Seq[Attribute], - dataColumns: Seq[Attribute], - partitionColumnSchema: StructType, - partitionValues: InternalRow, - dataRows: RDD[InternalRow]): RDD[InternalRow] = { - // If output columns contain any partition column(s), we need to merge scanned data - // columns and requested partition columns to form the final result. - if (requiredColumns != dataColumns) { - // Builds `AttributeReference`s for all partition columns so that we can use them to project - // required partition columns. Note that if a partition column appears in `requiredColumns`, - // we should use the `AttributeReference` in `requiredColumns`. - val partitionColumns = { - val requiredColumnMap = requiredColumns.map(a => a.name -> a).toMap - partitionColumnSchema.toAttributes.map { a => - requiredColumnMap.getOrElse(a.name, a) - } - } - - val mapPartitionsFunc = (_: TaskContext, _: Int, iterator: Iterator[Object]) => { - // Note that we can't use an `UnsafeRowJoiner` to replace the following `JoinedRow` and - // `UnsafeProjection`. Because the projection may also adjust column order. - val mutableJoinedRow = new JoinedRow() - val unsafePartitionValues = UnsafeProjection.create(partitionColumnSchema)(partitionValues) - val unsafeProjection = - UnsafeProjection.create(requiredColumns, dataColumns ++ partitionColumns) - - // If we are returning batches directly, we need to augment them with the partitioning - // columns. We want to do this without a row by row operation. - var columnBatch: ColumnarBatch = null - var mergedBatch: ColumnarBatch = null - - iterator.map { input => { - if (input.isInstanceOf[InternalRow]) { - unsafeProjection(mutableJoinedRow( - input.asInstanceOf[InternalRow], unsafePartitionValues)) - } else { - require(input.isInstanceOf[ColumnarBatch]) - val inputBatch = input.asInstanceOf[ColumnarBatch] - if (inputBatch != mergedBatch) { - mergedBatch = inputBatch - columnBatch = projectedColumnBatch(inputBatch, requiredColumns, - dataColumns, partitionColumnSchema, partitionValues) - } - columnBatch.setNumRows(inputBatch.numRows()) - columnBatch - } - }} - } - - // This is an internal RDD whose call site the user should not be concerned with - // Since we create many of these (one per partition), the time spent on computing - // the call site may add up. - Utils.withDummyCallSite(dataRows.sparkContext) { - new MapPartitionsRDD(dataRows, mapPartitionsFunc, preservesPartitioning = false) - }.asInstanceOf[RDD[InternalRow]] - } else { - dataRows - } - } - // Get the bucket ID based on the bucketing values. 
// Restriction: Bucket pruning works iff the bucketing column has one and only one column. def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { @@ -472,57 +133,6 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { bucketIdGeneration(mutableRow).getInt(0) } - // Get the bucket BitSet by reading the filters that only contains bucketing keys. - // Note: When the returned BitSet is None, no pruning is possible. - // Restriction: Bucket pruning works iff the bucketing column has one and only one column. - private def getBuckets( - filters: Seq[Expression], - bucketSpec: Option[BucketSpec]): Option[BitSet] = { - - if (bucketSpec.isEmpty || - bucketSpec.get.numBuckets == 1 || - bucketSpec.get.bucketColumnNames.length != 1) { - // None means all the buckets need to be scanned - return None - } - - // Just get the first because bucketing pruning only works when the column has one column - val bucketColumnName = bucketSpec.get.bucketColumnNames.head - val numBuckets = bucketSpec.get.numBuckets - val matchedBuckets = new BitSet(numBuckets) - matchedBuckets.clear() - - filters.foreach { - case expressions.EqualTo(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, v)) - case expressions.EqualTo(Literal(v, _), a: Attribute) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, v)) - case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, v)) - case expressions.EqualNullSafe(Literal(v, _), a: Attribute) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, v)) - // Because we only convert In to InSet in Optimizer when there are more than certain - // items. So it is possible we still get an In expression here that needs to be pushed - // down. - case expressions.In(a: Attribute, list) - if list.forall(_.isInstanceOf[Literal]) && a.name == bucketColumnName => - val hSet = list.map(e => e.eval(EmptyRow)) - hSet.foreach(e => matchedBuckets.set(getBucketId(a, numBuckets, e))) - case expressions.IsNull(a: Attribute) if a.name == bucketColumnName => - matchedBuckets.set(getBucketId(a, numBuckets, null)) - case _ => - } - - logInfo { - val selected = matchedBuckets.cardinality() - val percentPruned = (1 - selected.toDouble / numBuckets.toDouble) * 100 - s"Selected $selected buckets out of $numBuckets, pruned $percentPruned% partitions." - } - - // None means all the buckets need to be scanned - if (matchedBuckets.cardinality() == 0) None else Some(matchedBuckets) - } - // Based on Public API. 
protected def pruneFilterProject( relation: LogicalRelation, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index aa1f76450c..bcddf72851 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -55,15 +55,7 @@ import org.apache.spark.sql.sources._ */ private[sql] object FileSourceStrategy extends Strategy with Logging { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _)) - if (files.fileFormat.toString == "TestFileFormat" || - files.fileFormat.isInstanceOf[parquet.DefaultSource] || - files.fileFormat.toString == "ORC" || - files.fileFormat.toString == "LibSVM" || - files.fileFormat.isInstanceOf[csv.DefaultSource] || - files.fileFormat.isInstanceOf[text.DefaultSource] || - files.fileFormat.isInstanceOf[json.DefaultSource]) && - files.sqlContext.conf.useFileScan => + case PhysicalOperation(projects, filters, l @ LogicalRelation(files: HadoopFsRelation, _, _)) => // Filters on this relation fall into four categories based on where we can use them to avoid // reading unneeded data: // - partition keys only - used to prune directories to read diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 34fcbdf871..06a371b88b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -133,37 +133,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - /** - * This supports to eliminate unneeded columns before producing an RDD - * containing all of its tuples as Row objects. This reads all the tokens of each line - * and then drop unneeded tokens without casting and type-checking by mapping - * both the indices produced by `requiredColumns` and the ones of tokens. - */ - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - // TODO: Filter before calling buildInternalScan. 
- val csvFiles = inputFiles.filterNot(_.getPath.getName startsWith "_") - - val csvOptions = new CSVOptions(options) - val pathsString = csvFiles.map(_.getPath.toUri.toString) - val header = dataSchema.fields.map(_.name) - val tokenizedRdd = tokenRdd(sqlContext, csvOptions, header, pathsString) - val rows = CSVRelation.parseCsv(tokenizedRdd, dataSchema, requiredColumns, csvOptions) - - val requiredDataSchema = StructType(requiredColumns.map(c => dataSchema.find(_.name == c).get)) - rows.mapPartitions { iterator => - val unsafeProjection = UnsafeProjection.create(requiredDataSchema) - iterator.map(unsafeProjection) - } - } - private def baseRdd( sqlContext: SQLContext, options: CSVOptions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 42cd25a18c..f32fea4183 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -93,35 +93,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - // TODO: Filter files for all formats before calling buildInternalScan. - val jsonFiles = inputFiles.filterNot(_.getPath.getName startsWith "_") - - val parsedOptions: JSONOptions = new JSONOptions(options) - val requiredDataSchema = StructType(requiredColumns.map(dataSchema(_))) - val columnNameOfCorruptRecord = - parsedOptions.columnNameOfCorruptRecord - .getOrElse(sqlContext.conf.columnNameOfCorruptRecord) - val rows = JacksonParser.parse( - createBaseRdd(sqlContext, jsonFiles), - requiredDataSchema, - columnNameOfCorruptRecord, - parsedOptions) - - rows.mapPartitions { iterator => - val unsafeProjection = UnsafeProjection.create(requiredDataSchema) - iterator.map(unsafeProjection) - } - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index bcb2b2de13..dbda094996 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -251,12 +251,12 @@ private[sql] class DefaultSource } /** - * Returns whether the reader will the rows as batch or not. + * Returns whether the reader will return the rows as batch or not. 
*/ override def supportBatch(sqlContext: SQLContext, schema: StructType): Boolean = { val conf = SQLContext.getActive().get.conf - conf.useFileScan && conf.parquetVectorizedReaderEnabled && - conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && + conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && + schema.length <= conf.wholeStageMaxNumFields && schema.forall(_.dataType.isInstanceOf[AtomicType]) } @@ -375,110 +375,6 @@ private[sql] class DefaultSource } } } - - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - allFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) - val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown - val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString - val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp - - // Parquet row group size. We will use this value as the value for - // mapreduce.input.fileinputformat.split.minsize and mapred.min.split.size if the value - // of these flags are smaller than the parquet row group size. - val parquetBlockSize = ParquetOutputFormat.getLongBlockSize(broadcastedConf.value.value) - - // Create the function to set variable Parquet confs at both driver and executor side. - val initLocalJobFuncOpt = - ParquetRelation.initializeLocalJobFunc( - requiredColumns, - filters, - dataSchema, - parquetBlockSize, - useMetadataCache, - parquetFilterPushDown, - assumeBinaryIsString, - assumeInt96IsTimestamp) _ - - val inputFiles = splitFiles(allFiles).data.toArray - - // Create the function to set input paths at the driver side. - val setInputPaths = - ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _ - - val allPrimitiveTypes = dataSchema.forall(_.dataType.isInstanceOf[AtomicType]) - val inputFormatCls = if (sqlContext.conf.parquetVectorizedReaderEnabled - && allPrimitiveTypes) { - classOf[VectorizedParquetInputFormat] - } else { - classOf[ParquetInputFormat[InternalRow]] - } - - Utils.withDummyCallSite(sqlContext.sparkContext) { - new SqlNewHadoopRDD( - sqlContext = sqlContext, - broadcastedConf = broadcastedConf, - initDriverSideJobFuncOpt = Some(setInputPaths), - initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = inputFormatCls, - valueClass = classOf[InternalRow]) { - - val cacheMetadata = useMetadataCache - - @transient val cachedStatuses = inputFiles.map { f => - // In order to encode the authority of a Path containing special characters such as '/' - // (which does happen in some S3N credentials), we need to use the string returned by the - // URI of the path to create a new Path. - val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) - new FileStatus( - f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, f.getModificationTime, - f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) - }.toSeq - - private def escapePathUserInfo(path: Path): Path = { - val uri = path.toUri - new Path(new URI( - uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, - uri.getQuery, uri.getFragment)) - } - - // Overridden so we can inject our own cached files statuses. 
- override def getPartitions: Array[SparkPartition] = { - val inputFormat = new ParquetInputFormat[InternalRow] { - override def listStatus(jobContext: JobContext): JList[FileStatus] = { - if (cacheMetadata) cachedStatuses.asJava else super.listStatus(jobContext) - } - } - - val jobContext = new JobContextImpl(getConf(isDriverSide = true), jobId) - val rawSplits = inputFormat.getSplits(jobContext) - - Array.tabulate[SparkPartition](rawSplits.size) { i => - new SqlNewHadoopPartition( - id, i, rawSplits.get(i).asInstanceOf[InputSplit with Writable]) - } - } - } - } - } -} - -/** - * The ParquetInputFormat that create VectorizedParquetRecordReader. - */ -final class VectorizedParquetInputFormat extends ParquetInputFormat[InternalRow] { - override def createRecordReader( - inputSplit: InputSplit, - taskAttemptContext: TaskAttemptContext): RecordReader[Void, InternalRow] = { - new VectorizedParquetRecordReader().asInstanceOf[RecordReader[Void, InternalRow]] - } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 99459ba1d3..28b03ee7c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -88,45 +88,6 @@ class DefaultSource extends FileFormat with DataSourceRegister { } } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - verifySchema(dataSchema) - - val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) - val conf = job.getConfiguration - val paths = inputFiles - .filterNot(_.getPath.getName startsWith "_") - .map(_.getPath) - .sortBy(_.toUri) - - if (paths.nonEmpty) { - FileInputFormat.setInputPaths(job, paths: _*) - } - - sqlContext.sparkContext.hadoopRDD( - conf.asInstanceOf[JobConf], classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) - .mapPartitions { iter => - val unsafeRow = new UnsafeRow(1) - val bufferHolder = new BufferHolder(unsafeRow) - val unsafeRowWriter = new UnsafeRowWriter(bufferHolder, 1) - - iter.map { case (_, line) => - // Writes to an UnsafeRow directly - bufferHolder.reset() - unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.setTotalSize(bufferHolder.totalSize()) - unsafeRow - } - } - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index b58f960897..e74fb00cb2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -145,12 +145,6 @@ object SQLConf { .booleanConf .createWithDefault(true) - val USE_FILE_SCAN = SQLConfigBuilder("spark.sql.sources.fileScan") - .internal() - .doc("Use the new FileScanRDD path for reading HDSF based data sources.") - .booleanConf - .createWithDefault(true) - val PARQUET_SCHEMA_MERGING_ENABLED = SQLConfigBuilder("spark.sql.parquet.mergeSchema") .doc("When 
true, the Parquet data source merges schemas collected from all data files, " + "otherwise the schema is picked from the summary file or a random data file " + @@ -481,8 +475,6 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def useCompression: Boolean = getConf(COMPRESS_CACHED) - def useFileScan: Boolean = getConf(USE_FILE_SCAN) - def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION) def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 6acb41dd1f..65b1f61349 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -458,16 +458,6 @@ trait FileFormat { options: Map[String, String], dataSchema: StructType): OutputWriterFactory - def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] - /** * Returns whether this format support returning columnar batch or not. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 41f536fc37..90d7f53884 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -365,18 +365,6 @@ class TestFileFormat extends FileFormat { throw new NotImplementedError("JUST FOR TESTING") } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - throw new NotImplementedError("JUST FOR TESTING") - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 43f445edcb..e915f3dfe2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -111,19 +111,6 @@ private[sql] class DefaultSource } } - override def buildInternalScan( - sqlContext: SQLContext, - dataSchema: StructType, - requiredColumns: Array[String], - filters: Array[Filter], - bucketSet: Option[BitSet], - inputFiles: Seq[FileStatus], - broadcastedConf: Broadcast[SerializableConfiguration], - options: Map[String, String]): RDD[InternalRow] = { - val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(sqlContext, output, filters, inputFiles).execute() - } - override def buildReader( sqlContext: SQLContext, dataSchema: StructType, -- cgit v1.2.3 From b0f5497e9520575e5082fa8ce8be5569f43abe74 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 12 Apr 2016 00:43:28 -0700 Subject: [SPARK-14508][BUILD] Add a new ScalaStyle Rule `OmitBracesInCase` ## What changes were 
proposed in this pull request? According to the [Spark Code Style Guide](https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide) and [Scala Style Guide](http://docs.scala-lang.org/style/control-structures.html#curlybraces), we had better enforce the following rule. ``` case: Always omit braces in case clauses. ``` This PR makes a new ScalaStyle rule, 'OmitBracesInCase', and enforces it to the code. ## How was this patch tested? Pass the Jenkins tests (including Scala style checking) Author: Dongjoon Hyun Closes #12280 from dongjoon-hyun/SPARK-14508. --- .../main/scala/org/apache/spark/SparkContext.scala | 12 ++---- .../src/main/scala/org/apache/spark/SparkEnv.scala | 3 +- .../apache/spark/api/python/PythonHadoopUtil.scala | 3 +- .../org/apache/spark/api/python/PythonRDD.scala | 2 +- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 12 ++---- .../org/apache/spark/deploy/master/Master.scala | 48 ++++++++-------------- .../spark/deploy/master/MasterArguments.scala | 2 +- .../deploy/master/ZooKeeperPersistenceEngine.scala | 3 +- .../mesos/MesosClusterDispatcherArguments.scala | 3 +- .../spark/deploy/worker/ExecutorRunner.scala | 6 +-- .../org/apache/spark/deploy/worker/Worker.scala | 12 ++---- .../spark/deploy/worker/WorkerArguments.scala | 3 +- .../executor/CoarseGrainedExecutorBackend.scala | 3 +- .../org/apache/spark/metrics/MetricsSystem.scala | 3 +- .../org/apache/spark/partial/BoundedDouble.scala | 3 +- .../org/apache/spark/rdd/DoubleRDDFunctions.scala | 4 +- .../org/apache/spark/rdd/OrderedRDDFunctions.scala | 3 +- .../apache/spark/rdd/ParallelCollectionRDD.scala | 9 ++-- .../spark/rdd/PartitionerAwareUnionRDD.scala | 3 +- .../apache/spark/scheduler/InputFormatInfo.scala | 6 +-- .../org/apache/spark/scheduler/SplitInfo.scala | 3 +- .../apache/spark/scheduler/TaskResultGetter.scala | 2 +- .../apache/spark/scheduler/TaskSetManager.scala | 9 ++-- .../mesos/MesosClusterPersistenceEngine.scala | 3 +- .../cluster/mesos/MesosSchedulerBackendUtil.scala | 6 +-- .../cluster/mesos/MesosSchedulerUtils.scala | 6 +-- .../org/apache/spark/serializer/Serializer.scala | 3 +- .../storage/ShuffleBlockFetcherIterator.scala | 6 +-- .../spark/ui/exec/ExecutorThreadDumpPage.scala | 3 +- .../scala/org/apache/spark/util/EventLoop.scala | 3 +- .../org/apache/spark/util/SizeEstimator.scala | 3 +- .../scala/org/apache/spark/DistributedSuite.scala | 2 +- .../org/apache/spark/SparkContextInfoSuite.scala | 6 +-- .../scala/org/apache/spark/UnpersistSuite.scala | 2 +- .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 15 +++---- .../apache/spark/examples/CassandraCQLTest.scala | 6 +-- .../org/apache/spark/examples/CassandraTest.scala | 6 +-- .../scala/org/apache/spark/examples/LocalALS.scala | 6 +-- .../spark/examples/ml/OneVsRestExample.scala | 6 +-- .../spark/examples/mllib/DecisionTreeRunner.scala | 6 +-- .../streaming/kinesis/KinesisRecordProcessor.scala | 25 +++++------ .../org/apache/spark/ml/r/SparkRWrappers.scala | 6 +-- .../spark/mllib/clustering/GaussianMixture.scala | 3 +- .../mllib/clustering/GaussianMixtureModel.scala | 3 +- .../org/apache/spark/mllib/clustering/KMeans.scala | 6 +-- .../mllib/stat/test/KolmogorovSmirnovTest.scala | 3 +- .../apache/spark/repl/ExecutorClassLoader.scala | 36 ++++++++-------- scalastyle-config.xml | 5 +++ .../spark/sql/catalyst/expressions/Cast.scala | 3 +- .../expressions/EquivalentExpressions.scala | 4 +- .../org/apache/spark/sql/RandomDataGenerator.scala | 18 +++----- .../org/apache/spark/sql/DataFrameSuite.scala | 9 ++-- 
.../execution/vectorized/ColumnarBatchSuite.scala | 8 +--- .../apache/spark/streaming/dstream/DStream.scala | 3 +- .../streaming/dstream/DStreamCheckpointData.scala | 3 +- .../spark/streaming/dstream/FileInputDStream.scala | 3 +- .../spark/streaming/dstream/StateDStream.scala | 26 ++++-------- .../spark/streaming/BasicOperationsSuite.scala | 3 +- .../apache/spark/streaming/CheckpointSuite.scala | 6 +-- .../apache/spark/streaming/MasterFailureTest.scala | 3 +- .../spark/deploy/yarn/ApplicationMaster.scala | 3 +- .../apache/spark/deploy/yarn/YarnAllocator.scala | 3 +- .../scheduler/cluster/YarnSchedulerBackend.scala | 6 +-- .../deploy/yarn/YarnSparkHadoopUtilSuite.scala | 24 ++++------- 64 files changed, 164 insertions(+), 293 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f0d152f05a..966198dd5e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2397,9 +2397,8 @@ object SparkContext extends Logging { } catch { // TODO: Enumerate the exact reasons why it can fail // But irrespective of it, it means we cannot proceed ! - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } val backend = try { val clazz = @@ -2407,9 +2406,8 @@ object SparkContext extends Logging { val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } scheduler.initialize(backend) (backend, scheduler) @@ -2421,9 +2419,8 @@ object SparkContext extends Logging { cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } val backend = try { @@ -2432,9 +2429,8 @@ object SparkContext extends Logging { val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } scheduler.initialize(backend) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ab89f4c4e4..3d11db7461 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -101,14 +101,13 @@ class SparkEnv ( // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the // current working dir in executor which we do not need to delete. 
driverTmpDirToDelete match { - case Some(path) => { + case Some(path) => try { Utils.deleteRecursively(new File(path)) } catch { case e: Exception => logWarning(s"Exception while deleting Spark temp dir: $path", e) } - } case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 6f6730690f..6259bead3e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -134,11 +134,10 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable - case array: Array[Any] => { + case array: Array[Any] => val arrayWriteable = new ArrayWritable(classOf[Writable]) arrayWriteable.set(array.map(convertToWritable(_))) arrayWriteable - } case other => throw new SparkException( s"Data of type ${other.getClass.getName} cannot be used") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4bca16a234..ab5b6c8380 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -470,7 +470,7 @@ private[spark] object PythonRDD extends Logging { objs.append(obj) } } catch { - case eof: EOFException => {} + case eof: EOFException => // No-op } JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } finally { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 41ac308808..cda9d38c6a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -152,10 +152,9 @@ class SparkHadoopUtil extends Logging { val baselineBytesRead = f() Some(() => f() - baselineBytesRead) } catch { - case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) None - } } } @@ -174,10 +173,9 @@ class SparkHadoopUtil extends Logging { val baselineBytesWritten = f() Some(() => f() - baselineBytesWritten) } catch { - case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) None - } } } @@ -315,7 +313,7 @@ class SparkHadoopUtil extends Logging { */ def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = { text match { - case HADOOP_CONF_PATTERN(matched) => { + case HADOOP_CONF_PATTERN(matched) => logDebug(text + " matched " + HADOOP_CONF_PATTERN) val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } val eval = Option[String](hadoopConf.get(key)) @@ -330,11 +328,9 @@ class SparkHadoopUtil extends Logging { // Continue to substitute more variables. 
substituteHadoopVariables(eval.get, hadoopConf) } - } - case _ => { + case _ => logDebug(text + " didn't match " + HADOOP_CONF_PATTERN) text - } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 01901bbf85..9bd3fc1033 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -217,7 +217,7 @@ private[deploy] class Master( } override def receive: PartialFunction[Any, Unit] = { - case ElectedLeader => { + case ElectedLeader => val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE @@ -233,16 +233,14 @@ private[deploy] class Master( } }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } - } case CompleteRecovery => completeRecovery() - case RevokedLeadership => { + case RevokedLeadership => logError("Leadership has been revoked -- master shutting down.") System.exit(0) - } - case RegisterApplication(description, driver) => { + case RegisterApplication(description, driver) => // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response @@ -255,12 +253,11 @@ private[deploy] class Master( driver.send(RegisteredApplication(app.id, self)) schedule() } - } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => { + case ExecutorStateChanged(appId, execId, state, message, exitStatus) => val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) execOption match { - case Some(exec) => { + case Some(exec) => val appInfo = idToApp(appId) val oldState = exec.state exec.state = state @@ -298,22 +295,19 @@ private[deploy] class Master( } } } - } case None => logWarning(s"Got status update for unknown executor $appId/$execId") } - } - case DriverStateChanged(driverId, state, exception) => { + case DriverStateChanged(driverId, state, exception) => state match { case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED => removeDriver(driverId, state, exception) case _ => throw new Exception(s"Received unexpected state update for driver $driverId: $state") } - } - case Heartbeat(workerId, worker) => { + case Heartbeat(workerId, worker) => idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -327,9 +321,8 @@ private[deploy] class Master( " This worker was never registered, so ignoring the heartbeat.") } } - } - case MasterChangeAcknowledged(appId) => { + case MasterChangeAcknowledged(appId) => idToApp.get(appId) match { case Some(app) => logInfo("Application has been re-registered: " + appId) @@ -339,9 +332,8 @@ private[deploy] class Master( } if (canCompleteRecovery) { completeRecovery() } - } - case WorkerSchedulerStateResponse(workerId, executors, driverIds) => { + case WorkerSchedulerStateResponse(workerId, executors, driverIds) => idToWorker.get(workerId) match { case Some(worker) => logInfo("Worker has been re-registered: " + workerId) @@ -367,7 +359,6 @@ private[deploy] class Master( } if (canCompleteRecovery) { completeRecovery() } - } case WorkerLatestState(workerId, executors, driverIds) => idToWorker.get(workerId) match { @@ -397,9 +388,8 @@ private[deploy] class Master( logInfo(s"Received unregister request from application $applicationId") 
idToApp.get(applicationId).foreach(finishApplication) - case CheckForWorkerTimeOut => { + case CheckForWorkerTimeOut => timeOutDeadWorkers() - } case AttachCompletedRebuildUI(appId) => // An asyncRebuildSparkUI has completed, so need to attach to master webUi @@ -408,7 +398,7 @@ private[deploy] class Master( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterWorker( - id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => { + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -430,9 +420,8 @@ private[deploy] class Master( + workerAddress)) } } - } - case RequestSubmitDriver(description) => { + case RequestSubmitDriver(description) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + "Can only accept driver submissions in ALIVE state." @@ -451,9 +440,8 @@ private[deploy] class Master( context.reply(SubmitDriverResponse(self, true, Some(driver.id), s"Driver successfully submitted as ${driver.id}")) } - } - case RequestKillDriver(driverId) => { + case RequestKillDriver(driverId) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + s"Can only kill drivers in ALIVE state." @@ -484,9 +472,8 @@ private[deploy] class Master( context.reply(KillDriverResponse(self, driverId, success = false, msg)) } } - } - case RequestDriverStatus(driverId) => { + case RequestDriverStatus(driverId) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + "Can only request driver status in ALIVE state." 
@@ -501,18 +488,15 @@ private[deploy] class Master( context.reply(DriverStatusResponse(found = false, None, None, None, None)) } } - } - case RequestMasterState => { + case RequestMasterState => context.reply(MasterStateResponse( address.host, address.port, restServerBoundPort, workers.toArray, apps.toArray, completedApps.toArray, drivers.toArray, completedDrivers.toArray, state)) - } - case BoundPortsRequest => { + case BoundPortsRequest => context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) - } case RequestExecutors(appId, requestedTotal) => context.reply(handleRequestExecutors(appId, requestedTotal)) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 9cd7458ba0..585e0839d0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -78,7 +78,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { case ("--help") :: tail => printUsageAndExit(0) - case Nil => {} + case Nil => // No-op case _ => printUsageAndExit(1) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 79f77212fe..af850e4871 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -70,11 +70,10 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer try { Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { - case e: Exception => { + case e: Exception => logWarning("Exception while reading persisted file, deleting", e) zk.delete().forPath(WORKING_DIR + "/" + filename) None - } } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index b97805a28b..11e13441ee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -76,14 +76,13 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--help") :: tail => printUsageAndExit(0) - case Nil => { + case Nil => if (masterUrl == null) { // scalastyle:off println System.err.println("--master is required") // scalastyle:on println printUsageAndExit(1) } - } case _ => printUsageAndExit(1) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index f9c92c3bb9..06066248ea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -179,16 +179,14 @@ private[deploy] class ExecutorRunner( val message = "Command exited with code " + exitCode worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { - case interrupted: InterruptedException => { + case interrupted: InterruptedException => logInfo("Runner thread for executor " + fullId + " interrupted") state = ExecutorState.KILLED killProcess(None) - } - case e: Exception => { + case e: 
Exception => logError("Error running executor", e) state = ExecutorState.FAILED killProcess(Some(e.toString)) - } } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1b7637a39c..449beb0811 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -480,7 +480,7 @@ private[deploy] class Worker( memoryUsed += memory_ sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { - case e: Exception => { + case e: Exception => logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) if (executors.contains(appId + "/" + execId)) { executors(appId + "/" + execId).kill() @@ -488,7 +488,6 @@ private[deploy] class Worker( } sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(e.toString), None)) - } } } @@ -509,7 +508,7 @@ private[deploy] class Worker( } } - case LaunchDriver(driverId, driverDesc) => { + case LaunchDriver(driverId, driverDesc) => logInfo(s"Asked to launch driver $driverId") val driver = new DriverRunner( conf, @@ -525,9 +524,8 @@ private[deploy] class Worker( coresUsed += driverDesc.cores memoryUsed += driverDesc.mem - } - case KillDriver(driverId) => { + case KillDriver(driverId) => logInfo(s"Asked to kill driver $driverId") drivers.get(driverId) match { case Some(runner) => @@ -535,11 +533,9 @@ private[deploy] class Worker( case None => logError(s"Asked to kill unknown driver $driverId") } - } - case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => handleDriverStateChanged(driverStateChanged) - } case ReregisterWithMaster => reregisterWithMaster() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 391eb41190..777020d4d5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -165,12 +165,11 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { } // scalastyle:on classforname } catch { - case e: Exception => { + case e: Exception => totalMb = 2*1024 // scalastyle:off println System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") // scalastyle:on println - } } // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index d4ed5845e7..71b4ad160d 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -62,10 +62,9 @@ private[spark] class CoarseGrainedExecutorBackend( // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => // Always receive `true`. 
Just ignore it - case Failure(e) => { + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) System.exit(1) - } }(ThreadUtils.sameThread) } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4da1017d28..0fed991049 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -196,10 +196,9 @@ private[spark] class MetricsSystem private ( sinks += sink.asInstanceOf[Sink] } } catch { - case e: Exception => { + case e: Exception => logError("Sink class " + classPath + " cannot be instantiated") throw e - } } } } diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index c562c70aba..ab6aba6fc7 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -32,12 +32,11 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v */ override def equals(that: Any): Boolean = that match { - case that: BoundedDouble => { + case that: BoundedDouble => this.mean == that.mean && this.confidence == that.confidence && this.low == that.low && this.high == that.high - } case _ => false } } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 5e9230e733..368916a39e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -166,8 +166,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val counters = new Array[Long](buckets.length - 1) while (iter.hasNext) { bucketFunction(iter.next()) match { - case Some(x: Int) => {counters(x) += 1} - case _ => {} + case Some(x: Int) => counters(x) += 1 + case _ => // No-Op } } Iterator(counters) diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 363004e587..a5992022d0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -86,12 +86,11 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) val rddToFilter: RDD[P] = self.partitioner match { - case Some(rp: RangePartitioner[K, V]) => { + case Some(rp: RangePartitioner[K, V]) => val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match { case (l, u) => Math.min(l, u) to Math.max(l, u) } PartitionPruningRDD.create(self, partitionIndicies.contains) - } case _ => self } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 582fa93afe..462fb39ea2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -128,7 +128,7 @@ private object ParallelCollectionRDD { }) } seq match { - case r: Range => { + case r: Range => positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) => // If the range is inclusive, use inclusive range for the last slice if (r.isInclusive && index == 
numSlices - 1) { @@ -138,8 +138,7 @@ private object ParallelCollectionRDD { new Range(r.start + start * r.step, r.start + end * r.step, r.step) } }).toSeq.asInstanceOf[Seq[Seq[T]]] - } - case nr: NumericRange[_] => { + case nr: NumericRange[_] => // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) var r = nr @@ -149,14 +148,12 @@ private object ParallelCollectionRDD { r = r.drop(sliceSize) } slices - } - case _ => { + case _ => val array = seq.toArray // To prevent O(n^2) operations for List etc positions(array.length, numSlices).map({ case (start, end) => array.slice(start, end).toSeq }).toSeq - } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 9e3880714a..c3579d761d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -78,11 +78,10 @@ class PartitionerAwareUnionRDD[T: ClassTag]( logDebug("Finding preferred location for " + this + ", partition " + s.index) val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents val locations = rdds.zip(parentPartitions).flatMap { - case (rdd, part) => { + case (rdd, part) => val parentLocations = currPrefLocs(rdd, part) logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations) parentLocations - } } val location = if (locations.isEmpty) { None diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 0640f26051..a6b032cc00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -57,11 +57,10 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl // Since we are not doing canonicalization of path, this can be wrong : like relative vs // absolute path .. which is fine, this is best case effort to remove duplicates - right ? override def equals(other: Any): Boolean = other match { - case that: InputFormatInfo => { + case that: InputFormatInfo => // not checking config - that should be fine, right ? this.inputFormatClazz == that.inputFormatClazz && this.path == that.path - } case _ => false } @@ -86,10 +85,9 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl } } catch { - case e: ClassNotFoundException => { + case e: ClassNotFoundException => throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) - } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala index 6e9337bb90..bc1431835e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala @@ -49,14 +49,13 @@ class SplitInfo( // So unless there is identity equality between underlyingSplits, it will always fail even if it // is pointing to same block. 
override def equals(other: Any): Boolean = other match { - case that: SplitInfo => { + case that: SplitInfo => this.hostLocation == that.hostLocation && this.inputFormatClazz == that.inputFormatClazz && this.path == that.path && this.length == that.length && // other split specific checks (like start for FileSplit) this.underlyingSplit == that.underlyingSplit - } case _ => false } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 873f1b56bd..ae7ef46abb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -133,7 +133,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul // if we can't deserialize the reason. logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex: Exception => {} + case ex: Exception => // No-op } scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 15d3515a02..6e08cdd87a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -188,20 +188,18 @@ private[spark] class TaskSetManager( loc match { case e: ExecutorCacheTaskLocation => pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index - case e: HDFSCacheTaskLocation => { + case e: HDFSCacheTaskLocation => val exe = sched.getExecutorsAliveOnHost(loc.host) exe match { - case Some(set) => { + case Some(set) => for (e <- set) { pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer) += index } logInfo(s"Pending task $index has a cached location at ${e.host} " + ", where there are executors " + set.mkString(",")) - } case None => logDebug(s"Pending task $index has a cached location at ${e.host} " + ", but there are no executors alive there.") } - } case _ => } pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index @@ -437,7 +435,7 @@ private[spark] class TaskSetManager( } dequeueTask(execId, host, allowedLocality) match { - case Some((index, taskLocality, speculative)) => { + case Some((index, taskLocality, speculative)) => // Found a task; do some bookkeeping and return a task description val task = tasks(index) val taskId = sched.newTaskId() @@ -486,7 +484,6 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, taskName, index, serializedTask)) - } case _ => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index 3971e6c382..61ab3e87c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -121,11 +121,10 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine( Some(Utils.deserialize[T](fileData)) } catch { case e: NoNodeException => None - case e: Exception => { + case e: Exception => logWarning("Exception while reading persisted file, deleting", e) zk.delete().forPath(zkPath) None 
- } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 374c79a7e5..1b7ac172de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -55,11 +55,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { Some(vol.setContainerPath(container_path) .setHostPath(host_path) .setMode(Volume.Mode.RO)) - case spec => { + case spec => logWarning(s"Unable to parse volume specs: $volumes. " + "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"") None - } } } .map { _.build() } @@ -90,11 +89,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { Some(portmap.setHostPort(host_port.toInt) .setContainerPort(container_port.toInt) .setProtocol(protocol)) - case spec => { + case spec => logWarning(s"Unable to parse port mapping specs: $portmaps. " + "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"") None - } } } .map { _.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 233bdc23e6..7295d50682 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -124,11 +124,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { markErr() } } catch { - case e: Exception => { + case e: Exception => logError("driver.run() failed", e) error = Some(e) markErr() - } } } }.start() @@ -184,7 +183,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { var remain = amountToUse var requestedResources = new ArrayBuffer[Resource] val remainingResources = resources.asScala.map { - case r => { + case r => if (remain > 0 && r.getType == Value.Type.SCALAR && r.getScalar.getValue > 0.0 && @@ -196,7 +195,6 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } else { r } - } } // Filter any resource that has depleted. 
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 5ead40e89e..cb95246d5b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -188,10 +188,9 @@ abstract class DeserializationStream { try { (readKey[Any](), readValue[Any]()) } catch { - case eof: EOFException => { + case eof: EOFException => finished = true null - } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 25edb9f1e4..4ec5b4bbb0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -143,13 +143,12 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, address, _, buf, _) => { + case SuccessFetchResult(_, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() - } case _ => } } @@ -313,7 +312,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => { + case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) @@ -323,7 +322,6 @@ final class ShuffleBlockFetcherIterator( reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) } - } case _ => } // Send fetch requests up to maxBytesInFlight diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index cc476d61b5..a0ef80d9bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -38,7 +38,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val content = maybeThreadDump.map { threadDump => val dumpRows = threadDump.sortWith { - case (threadTrace1, threadTrace2) => { + case (threadTrace1, threadTrace2) => val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 if (v1 == v2) { @@ -46,7 +46,6 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } else { v1 > v2 } - } }.map { thread => val threadId = thread.threadId
    { + case NonFatal(e) => try { onError(e) } catch { case NonFatal(e) => logError("Unexpected error in " + name, e) } - } } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 3f627a0145..6861a75612 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -151,13 +151,12 @@ object SizeEstimator extends Logging { // TODO: We could use reflection on the VMOption returned ? getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { - case e: Exception => { + case e: Exception => // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) val guessInWords = if (guess) "yes" else "not" logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) return guess - } } } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 67d722c1dc..2110d3d770 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -320,7 +320,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex Thread.sleep(200) } } catch { - case _: Throwable => { Thread.sleep(10) } + case _: Throwable => Thread.sleep(10) // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 3706455c3f..8feb3dee05 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -82,20 +82,18 @@ package object testPackage extends Assertions { val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { - case CALL_SITE_REGEX(func, file, line) => { + case CALL_SITE_REGEX(func, file, line) => assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt - } case _ => fail("Did not match expected call site format") } curCallSite match { - case CALL_SITE_REGEX(func, file, line) => { + case CALL_SITE_REGEX(func, file, line) => assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) - } case _ => fail("Did not match expected call site format") } } diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index f7a13ab399..09e21646ee 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -35,7 +35,7 @@ class UnpersistSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(200) } } catch { - case _: Throwable => { Thread.sleep(10) } + case _: Throwable => Thread.sleep(10) // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. 
} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 43e61241b6..cebac2097f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -127,9 +127,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) val reply = rpcEndpointRef.askWithRetry[String]("hello") @@ -141,9 +140,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) @@ -164,10 +162,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => Thread.sleep(100) context.reply(msg) - } } }) @@ -317,10 +314,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { - case m => { + case m => self callSelfSuccessfully = true - } } }) @@ -682,9 +678,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = localEnv override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 973b005f91..ca4eea2356 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -106,9 +106,8 @@ object CassandraCQLTest { println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { - case (key, value) => { + case (key, value) => (ByteBufferUtil.string(value.get("prod_id")), ByteBufferUtil.toInt(value.get("quantity"))) - } } val aggregatedRDD = productSaleRDD.reduceByKey(_ + _) aggregatedRDD.collect().foreach { @@ -116,11 +115,10 @@ object CassandraCQLTest { } val casoutputCF = aggregatedRDD.map { - case (productId, saleCount) => { + case (productId, saleCount) => val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) (outKey, outVal) - } } casoutputCF.saveAsNewAPIHadoopFile( diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 6a8f73ad00..eff840d36e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -90,9 +90,8 @@ object CassandraTest { // Let us first get all the paragraphs from the retrieved rows val paraRdd = casRdd.map { - case (key, value) => { + case (key, value) => ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) - } } // Lets get the word count in paras @@ -103,7 +102,7 
@@ object CassandraTest { } counts.map { - case (word, count) => { + case (word, count) => val colWord = new org.apache.cassandra.thrift.Column() colWord.setName(ByteBufferUtil.bytes("word")) colWord.setValue(ByteBufferUtil.bytes(word)) @@ -122,7 +121,6 @@ object CassandraTest { mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(1).column_or_supercolumn.setColumn(colCount) (outputkey, mutations) - } }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], classOf[ColumnFamilyOutputFormat], job.getConfiguration) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index af5f216f28..fa10101955 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -104,16 +104,14 @@ object LocalALS { def main(args: Array[String]) { args match { - case Array(m, u, f, iters) => { + case Array(m, u, f, iters) => M = m.toInt U = u.toInt F = f.toInt ITERATIONS = iters.toInt - } - case _ => { + case _ => System.err.println("Usage: LocalALS ") System.exit(1) - } } showWarning() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index a0bb5dabf4..0b5d31c0ff 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -118,17 +118,15 @@ object OneVsRestExample { val inputData = sqlContext.read.format("libsvm").load(params.input) // compute the train/test split: if testInput is not provided use part of input. val data = params.testInput match { - case Some(t) => { + case Some(t) => // compute the number of features in the training set. val numFeatures = inputData.first().getAs[Vector](1).size val testData = sqlContext.read.option("numFeatures", numFeatures.toString) .format("libsvm").load(t) Array[DataFrame](inputData, testData) - } - case None => { + case None => val f = params.fracTest inputData.randomSplit(Array(1 - f, f), seed = 12345) - } } val Array(train, test) = data.map(_.cache()) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index c263f4f595..ee811d3aa1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -180,7 +180,7 @@ object DecisionTreeRunner { } // For classification, re-index classes if needed. 
val (examples, classIndexMap, numClasses) = algo match { - case Classification => { + case Classification => // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() val sortedClasses = classCounts.keys.toList.sorted @@ -209,7 +209,6 @@ object DecisionTreeRunner { println(s"$c\t$frac\t${classCounts(c)}") } (examples, classIndexMap, numClasses) - } case Regression => (origExamples, null, 0) case _ => @@ -225,7 +224,7 @@ object DecisionTreeRunner { case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } algo match { - case Classification => { + case Classification => // classCounts: class --> # examples in class val testExamples = { if (classIndexMap.isEmpty) { @@ -235,7 +234,6 @@ object DecisionTreeRunner { } } Array(examples, testExamples) - } case Regression => Array(examples, origTestExamples) } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 41c6ab123b..80e0cce055 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -73,7 +73,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") receiver.setCheckpointer(shardId, checkpointer) } catch { - case NonFatal(e) => { + case NonFatal(e) => /* * If there is a failure within the batch, the batch will not be checkpointed. * This will potentially cause records since the last checkpoint to be processed @@ -84,7 +84,6 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e - } } } else { /* RecordProcessor has been stopped. */ @@ -148,29 +147,25 @@ private[kinesis] object KinesisRecordProcessor extends Logging { /* If the function failed, either retry or throw the exception */ case util.Failure(e) => e match { /* Retry: Throttling or other Retryable exception has occurred */ - case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1 - => { - val backOffMillis = Random.nextInt(maxBackOffMillis) - Thread.sleep(backOffMillis) - logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) - retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) - } + case _: ThrottlingException | _: KinesisClientLibDependencyException + if numRetriesLeft > 1 => + val backOffMillis = Random.nextInt(maxBackOffMillis) + Thread.sleep(backOffMillis) + logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) + retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) /* Throw: Shutdown has been requested by the Kinesis Client Library. */ - case _: ShutdownException => { + case _: ShutdownException => logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) throw e - } /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ - case _: InvalidStateException => { + case _: InvalidStateException => logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + s" by the Amazon Kinesis Client Library. 
Table likely doesn't exist.", e) throw e - } /* Throw: Unexpected exception has occurred */ - case _ => { + case _ => logError(s"Unexpected, non-retryable exception.", e) throw e - } } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 551e75dc0a..fa143715be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -53,7 +53,7 @@ private[r] object SparkRWrappers { def getModelCoefficients(model: PipelineModel): Array[Double] = { model.stages.last match { - case m: LinearRegressionModel => { + case m: LinearRegressionModel => val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++ m.summary.coefficientStandardErrors.dropRight(1) val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1) @@ -64,14 +64,12 @@ private[r] object SparkRWrappers { } else { m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR } - } - case m: LogisticRegressionModel => { + case m: LogisticRegressionModel => if (m.getFitIntercept) { Array(m.intercept) ++ m.coefficients.toArray } else { m.coefficients.toArray } - } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 03eb903bb8..f04c87259c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -181,13 +181,12 @@ class GaussianMixture private ( val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - case None => { + case None => val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) - } } var llh = Double.MinValue // current log-likelihood diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 02417b1124..f87613cc72 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -183,7 +183,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val k = (metadata \ "k").extract[Int] val classNameV1_0 = SaveLoadV1_0.classNameV1_0 (loadedClassName, version) match { - case (classNameV1_0, "1.0") => { + case (classNameV1_0, "1.0") => val model = SaveLoadV1_0.load(sc, path) require(model.weights.length == k, s"GaussianMixtureModel requires weights of length $k " + @@ -192,7 +192,6 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { s"GaussianMixtureModel requires gaussians of length $k" + s"got gaussians of length ${model.gaussians.length}") model - } case _ => throw new Exception( s"GaussianMixtureModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). 
Supported:\n" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 37a21cd879..8ff0b83e8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -253,16 +253,14 @@ class KMeans private ( } val centers = initialModel match { - case Some(kMeansCenters) => { + case Some(kMeansCenters) => Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) - } - case None => { + case None => if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { initKMeansParallel(data) } - } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala index 0ec8975fed..ef284531c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -97,7 +97,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging { : KolmogorovSmirnovTestResult = { val distObj = distName match { - case "norm" => { + case "norm" => if (params.nonEmpty) { // parameters are passed, then can only be 2 require(params.length == 2, "Normal distribution requires mean and standard " + @@ -109,7 +109,6 @@ private[stat] object KolmogorovSmirnovTest extends Logging { "initialized to standard normal (i.e. N(0, 1))") new NormalDistribution(0, 1) } - } case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" + s" convenience method. Current options are:['norm'].") } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 928aaa5629..4a15d52b57 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -70,26 +70,24 @@ class ExecutorClassLoader( } override def findClass(name: String): Class[_] = { - userClassPathFirst match { - case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) - case false => { - try { - parentLoader.loadClass(name) - } catch { - case e: ClassNotFoundException => { - val classOption = findClassLocally(name) - classOption match { - case None => - // If this class has a cause, it will break the internal assumption of Janino - // (the compiler used for Spark SQL code-gen). - // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see - // its behavior will be changed if there is a cause and the compilation - // of generated class will fail. - throw new ClassNotFoundException(name) - case Some(a) => a - } + if (userClassPathFirst) { + findClassLocally(name).getOrElse(parentLoader.loadClass(name)) + } else { + try { + parentLoader.loadClass(name) + } catch { + case e: ClassNotFoundException => + val classOption = findClassLocally(name) + classOption match { + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). 
+ // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. + throw new ClassNotFoundException(name) + case Some(a) => a } - } } } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 472a8f4084..a14e3e583f 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -228,6 +228,11 @@ This file is divided into 3 sections: Use Javadoc style indentation for multiline comments + + case[^\n>]*=>\s*\{ + Omit braces in case clauses. + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d842ffdc66..0f8876a9e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -898,7 +898,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val result = ctx.freshName("result") val tmpRow = ctx.freshName("tmpRow") - val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => { + val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => val fromFieldPrim = ctx.freshName("ffp") val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") @@ -920,7 +920,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } """ - } }.mkString("\n") (c, evPrim, evNull) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index affd1bdb32..8d8cc152ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -97,11 +97,11 @@ class EquivalentExpressions { def debugString(all: Boolean = false): String = { val sb: mutable.StringBuilder = new StringBuilder() sb.append("Equivalent expressions:\n") - equivalenceMap.foreach { case (k, v) => { + equivalenceMap.foreach { case (k, v) => if (all || v.length > 1) { sb.append(" " + v.mkString(", ")).append("\n") } - }} + } sb.toString() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 8207d64798..711e870711 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -196,12 +196,11 @@ object RandomDataGenerator { case ShortType => randomNumeric[Short]( rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) case NullType => Some(() => null) - case ArrayType(elementType, containsNull) => { + case ArrayType(elementType, containsNull) => forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } - } - case MapType(keyType, valueType, valueContainsNull) => { + case MapType(keyType, valueType, valueContainsNull) => for ( keyGenerator <- forType(keyType, nullable = false, rand); valueGenerator <- @@ -221,8 +220,7 @@ object RandomDataGenerator { keys.zip(values).toMap } } - } - case StructType(fields) => { + case StructType(fields) => val 
maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => forType(field.dataType, nullable = field.nullable, rand) } @@ -232,8 +230,7 @@ object RandomDataGenerator { } else { None } - } - case udt: UserDefinedType[_] => { + case udt: UserDefinedType[_] => val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand) // Because random data generator at here returns scala value, we need to // convert it to catalyst value to call udt's deserialize. @@ -253,7 +250,6 @@ object RandomDataGenerator { } else { None } - } case unsupportedType => None } // Handle nullability by wrapping the non-null value generator: @@ -277,7 +273,7 @@ object RandomDataGenerator { val fields = mutable.ArrayBuffer.empty[Any] schema.fields.foreach { f => f.dataType match { - case ArrayType(childType, nullable) => { + case ArrayType(childType, nullable) => val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) { null } else { @@ -294,10 +290,8 @@ object RandomDataGenerator { arr } fields += data - } - case StructType(children) => { + case StructType(children) => fields += randomRow(rand, StructType(children)) - } case _ => val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand) assert(generator.isDefined, "Unsupported type") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 86c6405522..e953a6e8ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1153,14 +1153,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyNonExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => { + case agg: TungstenAggregate => atFirstAgg = !atFirstAgg - } - case _ => { + case _ => if (atFirstAgg) { fail("Should not have operators between the two aggregations") } - } } } @@ -1170,12 +1168,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => { + case agg: TungstenAggregate => if (atFirstAgg) { fail("Should not have back to back Aggregates") } atFirstAgg = true - } case e: ShuffleExchange => atFirstAgg = false case _ => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8a551cd78c..31b63f2ce1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -612,23 +612,20 @@ class ColumnarBatchSuite extends SparkFunSuite { val a2 = r2.getList(v._2).toArray assert(a1.length == a2.length, "Seed = " + seed) childType match { - case DoubleType => { + case DoubleType => var i = 0 while (i < a1.length) { assert(doubleEquals(a1(i).asInstanceOf[Double], a2(i).asInstanceOf[Double]), "Seed = " + seed) i += 1 } - } - case FloatType => { + case FloatType => var i = 0 while (i < a1.length) { assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]), "Seed = " + seed) i += 1 } - } - case t: DecimalType => var i = 0 while (i < a1.length) { @@ -640,7 +637,6 @@ class ColumnarBatchSuite extends SparkFunSuite { } i += 1 } - 
case _ => assert(a1 === a2, "Seed = " + seed) } case StructType(childFields) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index c40beeff97..58842f9c2f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -429,13 +429,12 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { - case Some(rdd) => { + case Some(rdd) => val jobFunc = () => { val emptyFunc = { (iterator: Iterator[T]) => {} } context.sparkContext.runJob(rdd, emptyFunc) } Some(new Job(time, jobFunc)) - } case None => None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 431c9dbe2c..e73837eb96 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -109,10 +109,9 @@ class DStreamCheckpointData[T: ClassTag](dstream: DStream[T]) def restore() { // Create RDDs from the checkpoint data currentCheckpointFiles.foreach { - case(time, file) => { + case(time, file) => logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 7fba2e8ec0..36f50e04db 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -333,14 +333,13 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( override def restore() { hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach { - case (t, f) => { + case (t, f) => // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + f.mkString("[", ", ", "]") ) batchTimeToSelectedFiles.synchronized { batchTimeToSelectedFiles += ((t, f)) } recentlySelectedFiles ++= f generatedRDDs += ((t, filesToRDD(f))) - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 0379957e58..28aed0ca45 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -65,14 +65,12 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // Try to get the previous state RDD getOrCompute(validTime - slideDuration) match { - case Some(prevStateRDD) => { // If previous state RDD exists - + case Some(prevStateRDD) => // If previous state RDD exists // Try to get the parent RDD parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual + case Some(parentRDD) => // If parent RDD exists, then compute as usual computeUsingPreviousRDD(parentRDD, prevStateRDD) - } - case None => { // If parent RDD does not exist + case None => // If parent RDD does not exist // Re-apply the 
update function to the old state RDD val updateFuncLocal = updateFunc @@ -82,17 +80,14 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( } val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) Some(stateRDD) - } } - } - - case None => { // If previous session RDD does not exist (first input data) + case None => // If previous session RDD does not exist (first input data) // Try to get the parent RDD parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual + case Some(parentRDD) => // If parent RDD exists, then compute as usual initialRDD match { - case None => { + case None => // Define the function for the mapPartition operation on grouped RDD; // first map the grouped tuple to tuples of required type, // and then apply the update function @@ -105,18 +100,13 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) // logDebug("Generating state RDD for time " + validTime + " (first)") Some(sessionRDD) - } - case Some(initialStateRDD) => { + case Some(initialStateRDD) => computeUsingPreviousRDD(parentRDD, initialStateRDD) - } } - } - case None => { // If parent RDD does not exist, then nothing to do! + case None => // If parent RDD does not exist, then nothing to do! // logDebug("Not generating state RDD (no previous state, no parent)") None - } } - } } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index bd60059b18..cfcbdc7c38 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -538,10 +538,9 @@ class BasicOperationsSuite extends TestSuiteBase { val stateObj = state.getOrElse(new StateObject) values.sum match { case 0 => stateObj.expireCounter += 1 // no new values - case n => { // has new values, increment and reset expireCounter + case n => // has new values, increment and reset expireCounter stateObj.counter += n stateObj.expireCounter = 0 - } } stateObj.expireCounter match { case 2 => None // seen twice with no new values, give it the boot diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index fbb25d4c59..bdbac64b9b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -267,10 +267,9 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") stateStream.checkpointData.currentCheckpointFiles.foreach { - case (time, file) => { + case (time, file) => assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") - } } // Run till a further time such that previous checkpoint files in the stream would be deleted @@ -297,10 +296,9 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") stateStream.checkpointData.currentCheckpointFiles.foreach { - case (time, file) => { + case 
(time, file) => assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") - } } ssc.stop() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 29bee4adf2..60c8e70235 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -382,11 +382,10 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) fs.rename(tempHadoopFile, hadoopFile) done = true } catch { - case ioe: IOException => { + case ioe: IOException => fs = testDir.getFileSystem(new Configuration()) logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe) - } } } if (!done) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 9e8453429c..d447a59937 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -374,7 +374,7 @@ private[spark] class ApplicationMaster( failureCount = 0 } catch { case i: InterruptedException => - case e: Throwable => { + case e: Throwable => failureCount += 1 // this exception was introduced in hadoop 2.4 and this code would not compile // with earlier versions if we refer it directly. @@ -390,7 +390,6 @@ private[spark] class ApplicationMaster( } else { logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e) } - } } try { val numPendingAllocate = allocator.getPendingAllocate.size diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index b0bfe855e9..23742eab62 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -148,11 +148,10 @@ private[yarn] class YarnAllocator( classOf[Array[String]], classOf[Array[String]], classOf[Priority], classOf[Boolean], classOf[String])) } catch { - case e: NoSuchMethodException => { + case e: NoSuchMethodException => logWarning(s"Node label expression $expr will be ignored because YARN version on" + " classpath does not support it.") None - } } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 8720ee57fe..6b3c831e60 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -223,17 +223,15 @@ private[spark] abstract class YarnSchedulerBackend( val lossReasonRequest = GetExecutorLossReason(executorId) val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) future onSuccess { - case reason: ExecutorLossReason => { + case reason: ExecutorLossReason => driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) - } } future onFailure { - case NonFatal(e) => { + case NonFatal(e) => logWarning(s"Attempted to get executor loss reason" + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + s" but got no response. 
Marking as slave lost.", e) driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) - } case t => throw t } case None => diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index de14e36f4e..fe09808ae5 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -101,22 +101,18 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) viewAcls match { - case Some(vacls) => { + case Some(vacls) => val aclSet = vacls.split(',').map(_.trim).toSet assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } modifyAcls match { - case Some(macls) => { + case Some(macls) => val aclSet = macls.split(',').map(_.trim).toSet assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } } @@ -135,26 +131,22 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) viewAcls match { - case Some(vacls) => { + case Some(vacls) => val aclSet = vacls.split(',').map(_.trim).toSet assert(aclSet.contains("user1")) assert(aclSet.contains("user2")) assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } modifyAcls match { - case Some(macls) => { + case Some(macls) => val aclSet = macls.split(',').map(_.trim).toSet assert(aclSet.contains("user3")) assert(aclSet.contains("user4")) assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } } -- cgit v1.2.3 From 124cbfb683a5e959e1b5181d4d0cc56956b50385 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 12 Apr 2016 22:28:57 +0800 Subject: [SPARK-14488][SPARK-14493][SQL] "CREATE TEMPORARY TABLE ... USING ... AS SELECT" shouldn't create persisted table ## What changes were proposed in this pull request? When planning logical plan node `CreateTableUsingAsSelect`, we neglected its `temporary` field and always generated a `CreateMetastoreDataSourceAsSelect`. This PR fixes the issue by generating a `CreateTempTableUsingAsSelect` when `temporary` is true. This PR also fixes SPARK-14493, since its root cause is that `CreateMetastoreDataSourceAsSelect` uses the default Hive warehouse location when the `PATH` data source option is absent. ## How was this patch tested? Added a test case to create a temporary table using the target syntax and check whether it's indeed a temporary table. Author: Cheng Lian Closes #12303 from liancheng/spark-14488-fix-ctas-using.
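For reference, a minimal sketch of the statement whose planning this patch changes (not part of the patch itself; it assumes a `SQLContext` named `sqlContext` is in scope, and the table names and Parquet path are purely illustrative):

```scala
// Illustrative sketch, assuming a SQLContext named `sqlContext` is available
// (e.g. in spark-shell). Table names and the Parquet path are made up.
sqlContext.range(10).registerTempTable("t1")   // hypothetical source table

// Before this fix, the CTAS below was planned as CreateMetastoreDataSourceAsSelect
// and produced a persisted metastore table; with the fix it is planned as
// CreateTempTableUsingAsSelect and only registers a temporary table.
sqlContext.sql(
  """CREATE TEMPORARY TABLE t2
    |USING PARQUET
    |OPTIONS (PATH '/tmp/spark-14488-example')
    |AS SELECT * FROM t1
  """.stripMargin)

// t2 should now show up as a temporary table.
sqlContext.tables().filter("tableName = 't2'").select("isTemporary").show()
```

With the fix, omitting the `OPTIONS (PATH ...)` clause also fails fast with "'path' is not specified" instead of silently writing to the default Hive warehouse location (SPARK-14493).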
--- .../org/apache/spark/sql/hive/HiveStrategies.scala | 10 +++-- .../spark/sql/hive/execution/SQLQuerySuite.scala | 49 ++++++++++++++++++++-- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index f44937ec6f..010361a32e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescribeCommand, _} -import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, - DescribeCommand} +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, CreateTempTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.hive.execution._ private[hive] trait HiveStrategies { @@ -90,6 +89,11 @@ private[hive] trait HiveStrategies { tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) ExecutedCommand(cmd) :: Nil + case c: CreateTableUsingAsSelect if c.temporary => + val cmd = CreateTempTableUsingAsSelect( + c.tableIdent, c.provider, c.partitionColumns, c.mode, c.options, c.child) + ExecutedCommand(cmd) :: Nil + case c: CreateTableUsingAsSelect => val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns, c.bucketSpec, c.mode, c.options, c.child) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index b4886eba7a..7eaf19dfe9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -19,12 +19,9 @@ package org.apache.spark.sql.hive.execution import java.sql.{Date, Timestamp} -import scala.collection.JavaConverters._ - import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} -import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} @@ -1852,4 +1849,50 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test( + "SPARK-14488 \"CREATE TEMPORARY TABLE ... USING ... AS SELECT ...\" " + + "shouldn't create persisted table" + ) { + withTempPath { dir => + withTempTable("t1", "t2") { + val path = dir.getCanonicalPath + val ds = sqlContext.range(10) + ds.registerTempTable("t1") + + sql( + s"""CREATE TEMPORARY TABLE t2 + |USING PARQUET + |OPTIONS (PATH '$path') + |AS SELECT * FROM t1 + """.stripMargin) + + checkAnswer( + sqlContext.tables().select('isTemporary).filter('tableName === "t2"), + Row(true) + ) + + checkAnswer(table("t2"), table("t1")) + } + } + } + + test( + "SPARK-14493 \"CREATE TEMPORARY TABLE ... USING ... 
AS SELECT ...\" " + + "shouldn always be used together with PATH data source option" + ) { + withTempTable("t") { + sqlContext.range(10).registerTempTable("t") + + val message = intercept[IllegalArgumentException] { + sql( + s"""CREATE TEMPORARY TABLE t1 + |USING PARQUET + |AS SELECT * FROM t + """.stripMargin) + }.getMessage + + assert(message == "'path' is not specified") + } + } } -- cgit v1.2.3 From da60b34d2f6eba19633e4f1b46504ce92cd6c179 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Tue, 12 Apr 2016 16:53:26 +0200 Subject: [SPARK-3724][ML] RandomForest: More options for feature subset size. ## What changes were proposed in this pull request? This PR tries to support more options for feature subset size in RandomForest implementation. Previously, RandomForest only support "auto", "all", "sort", "log2", "onethird". This PR tries to support any given value to allow model search. In this PR, `featureSubsetStrategy` could be passed with: a) a real number in the range of `(0.0-1.0]` that represents the fraction of the number of features in each subset, b) an integer number (`>0`) that represents the number of features in each subset. ## How was this patch tested? Two tests `JavaRandomForestClassifierSuite` and `JavaRandomForestRegressorSuite` have been updated to check the additional options for params in this PR. An additional test has been added to `org.apache.spark.mllib.tree.RandomForestSuite` to cover the cases in this PR. Author: Yong Tang Closes #11989 from yongtang/SPARK-3724. --- .../spark/ml/tree/impl/DecisionTreeMetadata.scala | 5 +++ .../org/apache/spark/ml/tree/treeParams.scala | 8 ++++- .../org/apache/spark/mllib/tree/RandomForest.scala | 11 +++++-- .../JavaRandomForestClassifierSuite.java | 19 ++++++++++++ .../regression/JavaRandomForestRegressorSuite.java | 19 ++++++++++++ .../spark/ml/tree/impl/RandomForestSuite.scala | 36 ++++++++++++++++++++++ 6 files changed, 95 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index df8eb5d1f9..c7cde1563f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -183,11 +183,16 @@ private[spark] object DecisionTreeMetadata extends Logging { } case _ => featureSubsetStrategy } + + val isIntRegex = "^([1-9]\\d*)$".r + val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r val numFeaturesPerNode: Int = _featureSubsetStrategy match { case "all" => numFeatures case "sqrt" => math.sqrt(numFeatures).ceil.toInt case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) case "onethird" => (numFeatures / 3.0).ceil.toInt + case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt + case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt } new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 78e6d3bfac..0767dc17e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -329,6 +329,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { * - "onethird": use 1/3 of the features * - "sqrt": use sqrt(number of features) * - "log2": use 
log2(number of features) + * - "n": when n is in the range (0, 1.0], use n * number of features. When n + * is in the range (1, number of features), use n features. * (default = "auto") * * These various settings are based on the following references: @@ -346,7 +348,8 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { "The number of features to consider for splits at each tree node." + s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}", (value: String) => - RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)) + RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase) + || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex)) setDefault(featureSubsetStrategy -> "auto") @@ -393,6 +396,9 @@ private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase) + + // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features) + final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$" } private[ml] trait RandomForestClassifierParams diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index 1841fa4a95..26755849ad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -55,10 +55,15 @@ import org.apache.spark.util.Utils * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. * @param featureSubsetStrategy Number of features to consider for splits at each node. * Supported values: "auto", "all", "sqrt", "log2", "onethird". + * Supported numerical values: "(0.0-1.0]", "[1-n]". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. + * If a real value "n" in the range (0, 1.0] is set, + * use n * number of features. + * If an integer value "n" in the range (1, num features) is set, + * use n features. * @param seed Random seed for bootstrapping and choosing feature subsets. */ private class RandomForest ( @@ -70,9 +75,11 @@ private class RandomForest ( strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") - require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), + require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy) + || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex), s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." 
+ - s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.") + s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," + + s" (0.0-1.0], [1-n].") /** * Method to train a decision tree model over an RDD diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 75061464e5..5aec52ac72 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -80,6 +81,24 @@ public class JavaRandomForestClassifierSuite implements Serializable { for (String featureSubsetStrategy: RandomForestClassifier.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } + String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; + for (String strategy: realStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String integerStrategies[] = {"1", "10", "100", "1000", "10000"}; + for (String strategy: integerStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; + for (String strategy: invalidStrategies) { + try { + rf.setFeatureSubsetStrategy(strategy); + Assert.fail("Expected exception to be thrown for invalid strategies"); + } catch (Exception e) { + Assert.assertTrue(e instanceof IllegalArgumentException); + } + } + RandomForestClassificationModel model = rf.fit(dataFrame); model.transform(dataFrame); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index b6f793f6de..a8736669f7 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -22,6 +22,7 @@ import java.util.HashMap; import java.util.Map; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -80,6 +81,24 @@ public class JavaRandomForestRegressorSuite implements Serializable { for (String featureSubsetStrategy: RandomForestRegressor.supportedFeatureSubsetStrategies()) { rf.setFeatureSubsetStrategy(featureSubsetStrategy); } + String realStrategies[] = {".1", ".10", "0.10", "0.1", "0.9", "1.0"}; + for (String strategy: realStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String integerStrategies[] = {"1", "10", "100", "1000", "10000"}; + for (String strategy: integerStrategies) { + rf.setFeatureSubsetStrategy(strategy); + } + String invalidStrategies[] = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"}; + for (String strategy: invalidStrategies) { + try { + rf.setFeatureSubsetStrategy(strategy); + Assert.fail("Expected exception to be thrown for invalid strategies"); + } catch (Exception e) { + Assert.assertTrue(e instanceof IllegalArgumentException); + } + } + RandomForestRegressionModel model = rf.fit(dataFrame); model.transform(dataFrame); diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index cd402b1e1f..6db9ce150d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -426,12 +426,48 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext { (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + val realStrategies = Array(".1", ".10", "0.10", "0.1", "0.9", "1.0") + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val integerStrategies = Array("1", "10", "100", "1000", "10000") + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 1, strategy, expected) + } + + val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0") + for (invalidStrategy <- invalidStrategies) { + intercept[MatchError]{ + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy) + } + } + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "log2", (math.log(numFeatures) / math.log(2)).ceil.toInt) checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + + for (strategy <- realStrategies) { + val expected = (strategy.toDouble * numFeatures).ceil.toInt + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + + for (strategy <- integerStrategies) { + val expected = if (strategy.toInt < numFeatures) strategy.toInt else numFeatures + checkFeatureSubsetStrategy(numTrees = 2, strategy, expected) + } + for (invalidStrategy <- invalidStrategies) { + intercept[MatchError]{ + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy) + } + } } test("Binary classification with continuous features: subsampling features") { -- cgit v1.2.3 From 6bf692147c21dd74e91e2bd95845f11ef0a303e6 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 12 Apr 2016 10:46:28 -0700 Subject: [SPARK-14474][SQL] Move FileSource offset log into checkpointLocation ## What changes were proposed in this pull request? Now that we have a single location for storing checkpointed state. This PR just propagates the checkpoint location into FileStreamSource so that we don't have one random log off on its own. ## How was this patch tested? test("metadataPath should be in checkpointLocation") Author: Shixiong Zhu Closes #12247 from zsxwing/file-source-log-location. 
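
As a quick sketch of the resulting layout (reusing the dummy `org.apache.spark.sql.streaming.test` provider exercised in DataFrameReaderWriterSuite below; the checkpoint path is illustrative), every source of a started query is materialized with its own metadata directory under the query's checkpoint location:

```scala
// Sketch only: two streaming sources in one query each receive a distinct
// metadata path rooted at the checkpoint location chosen by the user.
val df1 = sqlContext.read.format("org.apache.spark.sql.streaming.test").stream()
val df2 = sqlContext.read.format("org.apache.spark.sql.streaming.test").stream()

val query = df1.union(df2).write
  .format("org.apache.spark.sql.streaming.test")
  .option("checkpointLocation", "/checkpoints/q1")
  .startStream()

// ContinuousQueryManager now creates the sources with
//   source 0 -> /checkpoints/q1/sources/0
//   source 1 -> /checkpoints/q1/sources/1
// so a FileStreamSource no longer keeps a stray "<path>/_metadata" log of its own.
query.stop()
```

Deleting the checkpoint directory therefore removes the file source's offset log along with the rest of the query state.
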
--- .../apache/spark/sql/ContinuousQueryManager.scala | 5 +- .../sql/execution/datasources/DataSource.scala | 62 ++++++++++++------ .../execution/streaming/StreamingRelation.scala | 4 +- .../org/apache/spark/sql/sources/interfaces.scala | 9 +++ .../sql/streaming/DataFrameReaderWriterSuite.scala | 73 ++++++++++++++++++++-- .../sql/streaming/FileStreamSourceSuite.scala | 10 +-- .../spark/sql/streaming/MemorySinkSuite.scala | 2 +- .../apache/spark/sql/streaming/StreamSuite.scala | 9 +++ 8 files changed, 141 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index d7f71bd4b0..1343e81569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -178,10 +178,13 @@ class ContinuousQueryManager(sqlContext: SQLContext) { throw new IllegalArgumentException( s"Cannot start query with name $name as a query with that name is already active") } + var nextSourceId = 0L val logicalPlan = df.logicalPlan.transform { case StreamingRelation(dataSource, _, output) => // Materialize source to avoid creating it in every batch - val source = dataSource.createSource() + val metadataPath = s"$checkpointLocation/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 // We still need to use the previous `output` instead of `source.schema` as attributes in // "df.logicalPlan" has already used attributes of the previous `output`. StreamingExecutionRelation(source, output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index f55cedb1b6..10fde152ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -123,36 +123,58 @@ case class DataSource( } } - /** Returns a source that can be used to continually read data. */ - def createSource(): Source = { + private def inferFileFormatSchema(format: FileFormat): StructType = { + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val allPaths = caseInsensitiveOptions.get("path") + val globbedPaths = allPaths.toSeq.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + SparkHadoopUtil.get.globPathIfNecessary(qualified) + }.toArray + + val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) + userSpecifiedSchema.orElse { + format.inferSchema( + sqlContext, + caseInsensitiveOptions, + fileCatalog.allFiles()) + }.getOrElse { + throw new AnalysisException("Unable to infer schema. It must be specified manually.") + } + } + + /** Returns the name and schema of the source that can be used to continually read data. 
*/ + def sourceSchema(): (String, StructType) = { providingClass.newInstance() match { case s: StreamSourceProvider => - s.createSource(sqlContext, userSpecifiedSchema, className, options) + s.sourceSchema(sqlContext, userSpecifiedSchema, className, options) case format: FileFormat => val caseInsensitiveOptions = new CaseInsensitiveMap(options) val path = caseInsensitiveOptions.getOrElse("path", { throw new IllegalArgumentException("'path' is not specified") }) - val metadataPath = caseInsensitiveOptions.getOrElse("metadataPath", s"$path/_metadata") + (s"FileSource[$path]", inferFileFormatSchema(format)) + case _ => + throw new UnsupportedOperationException( + s"Data source $className does not support streamed reading") + } + } - val allPaths = caseInsensitiveOptions.get("path") - val globbedPaths = allPaths.toSeq.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(qualified) - }.toArray + /** Returns a source that can be used to continually read data. */ + def createSource(metadataPath: String): Source = { + providingClass.newInstance() match { + case s: StreamSourceProvider => + s.createSource(sqlContext, metadataPath, userSpecifiedSchema, className, options) - val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, None) - val dataSchema = userSpecifiedSchema.orElse { - format.inferSchema( - sqlContext, - caseInsensitiveOptions, - fileCatalog.allFiles()) - }.getOrElse { - throw new AnalysisException("Unable to infer schema. It must be specified manually.") - } + case format: FileFormat => + val caseInsensitiveOptions = new CaseInsensitiveMap(options) + val path = caseInsensitiveOptions.getOrElse("path", { + throw new IllegalArgumentException("'path' is not specified") + }) + + val dataSchema = inferFileFormatSchema(format) def dataFrameBuilder(files: Array[String]): DataFrame = { Dataset.ofRows( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index f951dea735..d2872e49ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.datasources.DataSource object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { - val source = dataSource.createSource() - StreamingRelation(dataSource, source.toString, source.schema.toAttributes) + val (name, schema) = dataSource.sourceSchema() + StreamingRelation(dataSource, name, schema.toAttributes) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 65b1f61349..bea243a3be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -129,8 +129,17 @@ trait SchemaRelationProvider { * Implemented by objects that can produce a streaming [[Source]] for a specific format or system. */ trait StreamSourceProvider { + + /** Returns the name and schema of the source that can be used to continually read data. 
*/ + def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) + def createSource( sqlContext: SQLContext, + metadataPath: String, schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index 28c558208f..00efe21d39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration._ +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark.sql._ @@ -31,22 +32,50 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils object LastOptions { + + var mockStreamSourceProvider = mock(classOf[StreamSourceProvider]) + var mockStreamSinkProvider = mock(classOf[StreamSinkProvider]) var parameters: Map[String, String] = null var schema: Option[StructType] = null var partitionColumns: Seq[String] = Nil + + def clear(): Unit = { + parameters = null + schema = null + partitionColumns = null + reset(mockStreamSourceProvider) + reset(mockStreamSinkProvider) + } } /** Dummy provider: returns no-op source/sink and records options in [[LastOptions]]. */ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { + + private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + LastOptions.parameters = parameters + LastOptions.schema = schema + LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters) + ("dummySource", fakeSchema) + } + override def createSource( sqlContext: SQLContext, + metadataPath: String, schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = { LastOptions.parameters = parameters LastOptions.schema = schema + LastOptions.mockStreamSourceProvider.createSource( + sqlContext, metadataPath, schema, providerName, parameters) new Source { - override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil) + override def schema: StructType = fakeSchema override def getOffset: Option[Offset] = Some(new LongOffset(0)) @@ -64,6 +93,7 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { partitionColumns: Seq[String]): Sink = { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns + LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns) new Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } @@ -117,7 +147,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(LastOptions.parameters("opt2") == "2") assert(LastOptions.parameters("opt3") == "3") - LastOptions.parameters = null + LastOptions.clear() df.write .format("org.apache.spark.sql.streaming.test") @@ -181,7 +211,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(LastOptions.parameters("path") == "/test") - LastOptions.parameters = null + LastOptions.clear() 
df.write .format("org.apache.spark.sql.streaming.test") @@ -204,7 +234,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(LastOptions.parameters("boolOpt") == "false") assert(LastOptions.parameters("doubleOpt") == "6.7") - LastOptions.parameters = null + LastOptions.clear() df.write .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) @@ -303,4 +333,39 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B assert(q.asInstanceOf[StreamExecution].trigger == ProcessingTime(100000)) } + + test("source metadataPath") { + LastOptions.clear() + + val checkpointLocation = newMetadataDir + + val df1 = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val df2 = sqlContext.read + .format("org.apache.spark.sql.streaming.test") + .stream() + + val q = df1.union(df2).write + .format("org.apache.spark.sql.streaming.test") + .option("checkpointLocation", checkpointLocation) + .trigger(ProcessingTime(10.seconds)) + .startStream() + q.stop() + + verify(LastOptions.mockStreamSourceProvider).createSource( + sqlContext, + checkpointLocation + "/sources/0", + None, + "org.apache.spark.sql.streaming.test", + Map.empty) + + verify(LastOptions.mockStreamSourceProvider).createSource( + sqlContext, + checkpointLocation + "/sources/1", + None, + "org.apache.spark.sql.streaming.test", + Map.empty) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 09daa7f81a..73d1b1b1d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -63,6 +63,7 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { format: String, path: String, schema: Option[StructType] = None): FileStreamSource = { + val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath val reader = if (schema.isDefined) { sqlContext.read.format(format).schema(schema.get) @@ -72,7 +73,8 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { reader.stream(path) .queryExecution.analyzed .collect { case StreamingRelation(dataSource, _, _) => - dataSource.createSource().asInstanceOf[FileStreamSource] + // There is only one source in our tests so just set sourceId to 0 + dataSource.createSource(s"$checkpointLocation/sources/0").asInstanceOf[FileStreamSource] }.head } @@ -98,9 +100,8 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { } df.queryExecution.analyzed .collect { case StreamingRelation(dataSource, _, _) => - dataSource.createSource().asInstanceOf[FileStreamSource] - }.head - .schema + dataSource.sourceSchema() + }.head._2 } test("FileStreamSource schema: no path") { @@ -340,7 +341,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { Utils.deleteRecursively(src) Utils.deleteRecursively(tmp) } - } class FileStreamSourceStressTestSuite extends FileStreamSourceTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 5249aa28dd..1f28340545 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -59,7 +59,7 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { } test("error if attempting to resume specific checkpoint") { - val location = Utils.createTempDir("steaming.checkpoint").getCanonicalPath + val location = Utils.createTempDir(namePrefix = "steaming.checkpoint").getCanonicalPath val input = MemoryStream[Int] val query = input.toDF().write diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index e4ea555526..2bd27c7efd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -115,8 +115,17 @@ class StreamSuite extends StreamTest with SharedSQLContext { */ class FakeDefaultSource extends StreamSourceProvider { + private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) + override def createSource( sqlContext: SQLContext, + metadataPath: String, schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = { -- cgit v1.2.3 From 75e05a5a964c9585dd09a2ef6178881929bab1f1 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 12 Apr 2016 10:51:07 -0700 Subject: [SPARK-12566][SPARK-14324][ML] GLM model family, link function support in SparkR:::glm * SparkR glm supports families and link functions which match R's signature for family. * SparkR glm API refactor. The comparative standard of the new API is R glm, so I only expose the arguments that R glm supports: ```formula, family, data, epsilon and maxit```. * This PR is focus on glm() and predict(), summary statistics will be done in a separate PR after this get in. * This PR depends on #12287 which make GLMs support link prediction at Scala side. After that merged, I will add more tests for predict() to this PR. Unit tests. cc mengxr jkbradley hhbyyh Author: Yanbo Liang Closes #12294 from yanboliang/spark-12566. 
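
For readers tracking the JVM side, a condensed sketch of roughly what a SparkR call such as `glm(Sepal_Length ~ Sepal_Width, df, family = poisson(link = identity))` maps to through the new wrapper (assuming `data` is the underlying Scala DataFrame; the authoritative code is `GeneralizedLinearRegressionWrapper.fit` in the diff below):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.GeneralizedLinearRegression

// R's family()/link pair arrives as two strings; epsilon and maxit map to tol and maxIter.
val rFormula = new RFormula().setFormula("Sepal_Length ~ Sepal_Width")
val glm = new GeneralizedLinearRegression()
  .setFamily("poisson")
  .setLink("identity")
  .setFitIntercept(true)  // the wrapper uses rFormula.hasIntercept, i.e. false for "y ~ x + 0"
  .setTol(1e-6)           // R side: epsilon
  .setMaxIter(25)         // R side: maxit

// The wrapper fits the RFormula first so it can record feature names; chaining both
// stages in a single Pipeline is an equivalent shortcut for this sketch.
val model = new Pipeline().setStages(Array(rFormula, glm)).fit(data)
val predictions = model.transform(data)
```

The recorded feature names are what allow summary() on the R side to label the returned coefficient matrix, including the "(Intercept)" row when an intercept is fit.
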
--- R/pkg/R/mllib.R | 139 +++++++++------------ R/pkg/inst/tests/testthat/test_mllib.R | 95 +++++--------- .../ml/r/GeneralizedLinearRegressionWrapper.scala | 79 ++++++++++++ .../org/apache/spark/ml/r/SparkRWrappers.scala | 115 ----------------- 4 files changed, 169 insertions(+), 259 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f3152cc232..31bca16580 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -17,10 +17,10 @@ # mllib.R: Provides methods for MLlib integration -#' @title S4 class that represents a PipelineModel -#' @param model A Java object reference to the backing Scala PipelineModel +#' @title S4 class that represents a generalized linear model +#' @param jobj a Java object reference to the backing Scala GeneralizedLinearRegressionWrapper #' @export -setClass("PipelineModel", representation(model = "jobj")) +setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj")) #' @title S4 class that represents a NaiveBayesModel #' @param jobj a Java object reference to the backing Scala NaiveBayesWrapper @@ -39,21 +39,18 @@ setClass("KMeansModel", representation(jobj = "jobj")) #' Fits a generalized linear model #' -#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' Fits a generalized linear model, similarly to R's glm(). #' #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. -#' @param data DataFrame for training -#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. -#' @param lambda Regularization parameter -#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) -#' @param standardize Whether to standardize features before training -#' @param solver The solver algorithm used for optimization, this can be "l-bfgs", "normal" and -#' "auto". "l-bfgs" denotes Limited-memory BFGS which is a limited-memory -#' quasi-Newton optimization method. "normal" denotes using Normal Equation as an -#' analytical solution to the linear regression problem. The default value is "auto" -#' which means that the solver algorithm is selected automatically. -#' @return a fitted MLlib model +#' @param data DataFrame for training. +#' @param family A description of the error distribution and link function to be used in the model. +#' This can be a character string naming a family function, a family function or +#' the result of a call to a family function. Refer R family at +#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' @param epsilon Positive convergence tolerance of iterations. +#' @param maxit Integer giving the maximal number of IRLS iterations. 
+#' @return a fitted generalized linear model #' @rdname glm #' @export #' @examples @@ -64,25 +61,59 @@ setClass("KMeansModel", representation(jobj = "jobj")) #' df <- createDataFrame(sqlContext, iris) #' model <- glm(Sepal_Length ~ Sepal_Width, df, family="gaussian") #' summary(model) -#'} +#' } setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), - function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0, - standardize = TRUE, solver = "auto") { - family <- match.arg(family) + function(formula, family = gaussian, data, epsilon = 1e-06, maxit = 25) { + if (is.character(family)) { + family <- get(family, mode = "function", envir = parent.frame()) + } + if (is.function(family)) { + family <- family() + } + if (is.null(family$family)) { + print(family) + stop("'family' not recognized") + } + formula <- paste(deparse(formula), collapse = "") - model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "fitRModelFormula", formula, data@sdf, family, lambda, - alpha, standardize, solver) - return(new("PipelineModel", model = model)) + + jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", + "fit", formula, data@sdf, family$family, family$link, + epsilon, as.integer(maxit)) + return(new("GeneralizedLinearRegressionModel", jobj = jobj)) }) -#' Make predictions from a model +#' Get the summary of a generalized linear model #' -#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' Returns the summary of a model produced by glm(), similarly to R's summary(). #' -#' @param object A fitted MLlib model +#' @param object A fitted generalized linear model +#' @return coefficients the model's coefficients, intercept +#' @rdname summary +#' @export +#' @examples +#' \dontrun{ +#' model <- glm(y ~ x, trainingData) +#' summary(model) +#' } +setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), + function(object, ...) { + jobj <- object@jobj + features <- callJMethod(jobj, "rFeatures") + coefficients <- callJMethod(jobj, "rCoefficients") + coefficients <- as.matrix(unlist(coefficients)) + colnames(coefficients) <- c("Estimate") + rownames(coefficients) <- unlist(features) + return(list(coefficients = coefficients)) + }) + +#' Make predictions from a generalized linear model +#' +#' Makes predictions from a generalized linear model produced by glm(), similarly to R's predict(). +#' +#' @param object A fitted generalized linear model #' @param newData DataFrame for testing -#' @return DataFrame containing predicted values +#' @return DataFrame containing predicted labels in a column named "prediction" #' @rdname predict #' @export #' @examples @@ -90,10 +121,10 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFram #' model <- glm(y ~ x, trainingData) #' predicted <- predict(model, testData) #' showDF(predicted) -#'} -setMethod("predict", signature(object = "PipelineModel"), +#' } +setMethod("predict", signature(object = "GeneralizedLinearRegressionModel"), function(object, newData) { - return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) #' Make predictions from a naive Bayes model @@ -116,54 +147,6 @@ setMethod("predict", signature(object = "NaiveBayesModel"), return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf))) }) -#' Get the summary of a model -#' -#' Returns the summary of a model produced by glm(), similarly to R's summary(). 
-#' -#' @param object A fitted MLlib model -#' @return a list with 'devianceResiduals' and 'coefficients' components for gaussian family -#' or a list with 'coefficients' component for binomial family. \cr -#' For gaussian family: the 'devianceResiduals' gives the min/max deviance residuals -#' of the estimation, the 'coefficients' gives the estimated coefficients and their -#' estimated standard errors, t values and p-values. (It only available when model -#' fitted by normal solver.) \cr -#' For binomial family: the 'coefficients' gives the estimated coefficients. -#' See summary.glm for more information. \cr -#' @rdname summary -#' @export -#' @examples -#' \dontrun{ -#' model <- glm(y ~ x, trainingData) -#' summary(model) -#'} -setMethod("summary", signature(object = "PipelineModel"), - function(object, ...) { - modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", object@model) - features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", object@model) - coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", object@model) - if (modelName == "LinearRegressionModel") { - devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelDevianceResiduals", object@model) - devianceResiduals <- matrix(devianceResiduals, nrow = 1) - colnames(devianceResiduals) <- c("Min", "Max") - rownames(devianceResiduals) <- rep("", times = 1) - coefficients <- matrix(coefficients, ncol = 4) - colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") - rownames(coefficients) <- unlist(features) - return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) - } else if (modelName == "LogisticRegressionModel") { - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") - rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) - } else { - stop(paste("Unsupported model", modelName, sep = " ")) - } - }) - #' Get the summary of a naive Bayes model #' #' Returns the summary of a naive Bayes model produced by naiveBayes(), similarly to R's summary(). diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index fdb591756e..a9dbd2bdc4 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -25,20 +25,21 @@ sc <- sparkR.init() sqlContext <- sparkRSQL.init(sc) -test_that("glm and predict", { +test_that("formula of glm", { training <- suppressWarnings(createDataFrame(sqlContext, iris)) - test <- select(training, "Sepal_Length") - model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") - prediction <- predict(model, test) - expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + # dot minus and intercept vs native glm + model <- glm(Sepal_Width ~ . - Species + 0, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ . 
- Species + 0, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) - # Test stats::predict is working - x <- rnorm(15) - y <- x + rnorm(15) - expect_equal(length(predict(lm(y ~ x))), 15) -}) + # feature interaction vs native glm + model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -test_that("glm should work with long formula", { + # glm should work with long formula training <- suppressWarnings(createDataFrame(sqlContext, iris)) training$LongLongLongLongLongName <- training$Sepal_Width training$VeryLongLongLongLonLongName <- training$Sepal_Length @@ -50,68 +51,30 @@ test_that("glm should work with long formula", { expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) }) -test_that("predictions match with native glm", { +test_that("glm and predict", { training <- suppressWarnings(createDataFrame(sqlContext, iris)) + # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) - vals <- collect(select(predict(model, training), "prediction")) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) rVals <- predict(glm(Sepal.Width ~ Sepal.Length + Species, data = iris), iris) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) - -test_that("dot minus and intercept vs native glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) - model <- glm(Sepal_Width ~ . - Species + 0, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ . 
- Species + 0, data = iris), iris) - expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) -test_that("feature interaction vs native glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) - model <- glm(Sepal_Width ~ Species:Sepal_Length, data = training) - vals <- collect(select(predict(model, training), "prediction")) - rVals <- predict(glm(Sepal.Width ~ Species:Sepal.Length, data = iris), iris) + # poisson family + model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training, + family = poisson(link = identity)) + prediction <- predict(model, training) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") + vals <- collect(select(prediction, "prediction")) + rVals <- suppressWarnings(predict(glm(Sepal.Width ~ Sepal.Length + Species, + data = iris, family = poisson(link = identity)), iris)) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) -}) -test_that("summary coefficients match with native glm", { - training <- suppressWarnings(createDataFrame(sqlContext, iris)) - stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training, solver = "normal")) - coefs <- unlist(stats$coefficients) - devianceResiduals <- unlist(stats$devianceResiduals) - - rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) - rCoefs <- unlist(rStats$coefficients) - rDevianceResiduals <- c(-0.95096, 0.72918) - - expect_true(all(abs(rCoefs - coefs) < 1e-5)) - expect_true(all(abs(rDevianceResiduals - devianceResiduals) < 1e-5)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) -}) - -test_that("summary coefficients match with native glm of family 'binomial'", { - df <- suppressWarnings(createDataFrame(sqlContext, iris)) - training <- filter(df, df$Species != "setosa") - stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, - family = "binomial")) - coefs <- as.vector(stats$coefficients[, 1]) - - rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] - rCoefs <- as.vector(coef(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, - family = binomial(link = "logit")))) - - expect_true(all(abs(rCoefs - coefs) < 1e-4)) - expect_true(all( - rownames(stats$coefficients) == - c("(Intercept)", "Sepal_Length", "Sepal_Width"))) -}) - -test_that("summary works on base GLM models", { - baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) - baseSummary <- summary(baseModel) - expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) + # Test stats::predict is working + x <- rnorm(15) + y <- x + rnorm(15) + expect_equal(length(predict(lm(y ~ x))), 15) }) test_that("kmeans", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala new file mode 100644 index 0000000000..475a308385 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.r + +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.ml.attribute.AttributeGroup +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.regression._ +import org.apache.spark.sql._ + +private[r] class GeneralizedLinearRegressionWrapper private ( + pipeline: PipelineModel, + val features: Array[String]) { + + private val glm: GeneralizedLinearRegressionModel = + pipeline.stages(1).asInstanceOf[GeneralizedLinearRegressionModel] + + lazy val rCoefficients: Array[Double] = if (glm.getFitIntercept) { + Array(glm.intercept) ++ glm.coefficients.toArray + } else { + glm.coefficients.toArray + } + + lazy val rFeatures: Array[String] = if (glm.getFitIntercept) { + Array("(Intercept)") ++ features + } else { + features + } + + def transform(dataset: DataFrame): DataFrame = { + pipeline.transform(dataset).drop(glm.getFeaturesCol) + } +} + +private[r] object GeneralizedLinearRegressionWrapper { + + def fit( + formula: String, + data: DataFrame, + family: String, + link: String, + epsilon: Double, + maxit: Int): GeneralizedLinearRegressionWrapper = { + val rFormula = new RFormula() + .setFormula(formula) + val rFormulaModel = rFormula.fit(data) + // get labels and feature names from output schema + val schema = rFormulaModel.transform(data).schema + val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol)) + .attributes.get + val features = featureAttrs.map(_.name.get) + // assemble and fit the pipeline + val glm = new GeneralizedLinearRegression() + .setFamily(family) + .setLink(link) + .setFitIntercept(rFormula.hasIntercept) + .setTol(epsilon) + .setMaxIter(maxit) + val pipeline = new Pipeline() + .setStages(Array(rFormulaModel, glm)) + .fit(data) + new GeneralizedLinearRegressionWrapper(pipeline, features) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala deleted file mode 100644 index fa143715be..0000000000 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ml.api.r - -import org.apache.spark.ml.{Pipeline, PipelineModel} -import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} -import org.apache.spark.ml.feature.RFormula -import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} -import org.apache.spark.sql.DataFrame - -private[r] object SparkRWrappers { - def fitRModelFormula( - value: String, - df: DataFrame, - family: String, - lambda: Double, - alpha: Double, - standardize: Boolean, - solver: String): PipelineModel = { - val formula = new RFormula().setFormula(value) - val estimator = family match { - case "gaussian" => new LinearRegression() - .setRegParam(lambda) - .setElasticNetParam(alpha) - .setFitIntercept(formula.hasIntercept) - .setStandardization(standardize) - .setSolver(solver) - case "binomial" => new LogisticRegression() - .setRegParam(lambda) - .setElasticNetParam(alpha) - .setFitIntercept(formula.hasIntercept) - .setStandardization(standardize) - } - val pipeline = new Pipeline().setStages(Array(formula, estimator)) - pipeline.fit(df) - } - - def getModelCoefficients(model: PipelineModel): Array[Double] = { - model.stages.last match { - case m: LinearRegressionModel => - val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++ - m.summary.coefficientStandardErrors.dropRight(1) - val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1) - val pValuesR = Array(m.summary.pValues.last) ++ m.summary.pValues.dropRight(1) - if (m.getFitIntercept) { - Array(m.intercept) ++ m.coefficients.toArray ++ coefficientStandardErrorsR ++ - tValuesR ++ pValuesR - } else { - m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR - } - case m: LogisticRegressionModel => - if (m.getFitIntercept) { - Array(m.intercept) ++ m.coefficients.toArray - } else { - m.coefficients.toArray - } - } - } - - def getModelDevianceResiduals(model: PipelineModel): Array[Double] = { - model.stages.last match { - case m: LinearRegressionModel => - m.summary.devianceResiduals - case m: LogisticRegressionModel => - throw new UnsupportedOperationException( - "No deviance residuals available for LogisticRegressionModel") - } - } - - def getModelFeatures(model: PipelineModel): Array[String] = { - model.stages.last match { - case m: LinearRegressionModel => - val attrs = AttributeGroup.fromStructField( - m.summary.predictions.schema(m.summary.featuresCol)) - if (m.getFitIntercept) { - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - } else { - attrs.attributes.get.map(_.name.get) - } - case m: LogisticRegressionModel => - val attrs = AttributeGroup.fromStructField( - m.summary.predictions.schema(m.summary.featuresCol)) - if (m.getFitIntercept) { - Array("(Intercept)") ++ attrs.attributes.get.map(_.name.get) - } else { - attrs.attributes.get.map(_.name.get) - } - } - } - - def getModelName(model: PipelineModel): String = { - model.stages.last match { - case m: LinearRegressionModel => - "LinearRegressionModel" - case m: LogisticRegressionModel => - "LogisticRegressionModel" - } - } -} -- cgit v1.2.3 From 101663f1ae222a919fc40510aa4f2bad22d1be6f Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 12 Apr 2016 11:27:16 -0700 Subject: [SPARK-13322][ML] AFTSurvivalRegression supports feature standardization ## What changes were proposed in this pull request? AFTSurvivalRegression should support feature standardization, it will improve the convergence rate. 
Test the convergence rate on the [Ovarian](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/ovarian.html) data which is standard data comes with Survival library in R, * without standardization(before this PR) -> 74 iterations. * with standardization(after this PR) -> 38 iterations. But after this fix, with or without ```standardization``` will converge to the same solution. It means that ```standardization = false``` will run the same code route as ```standardization = true```. Because if the features are not standardized at all, it will result convergency issue when the features have very different scales. This behavior is the same as ML [```LinearRegression``` and ```LogisticRegression```](https://issues.apache.org/jira/browse/SPARK-8522). See more discussion about this topic at #11247. cc mengxr ## How was this patch tested? unit test. Author: Yanbo Liang Closes #11365 from yanboliang/spark-13322. --- .../ml/regression/AFTSurvivalRegression.scala | 105 ++++++++++++++------- .../ml/regression/AFTSurvivalRegressionSuite.scala | 22 +++++ 2 files changed, 93 insertions(+), 34 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index afed1f32b6..89ba6ab5d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -31,6 +31,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -198,10 +199,20 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) - val costFun = new AFTCostFun(instances, $(fitIntercept)) + val featuresSummarizer = { + val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features) + val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => { + c1.merge(c2) + } + instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp) + } + + val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) + + val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) - val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size + val numFeatures = featuresStd.size /* The parameters vector has three parts: the first element: Double, log(sigma), the log of scale parameter @@ -230,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S if (handlePersistence) instances.unpersist() - val coefficients = Vectors.dense(parameters.slice(2, parameters.length)) + val rawCoefficients = parameters.slice(2, parameters.length) + var i = 0 + while (i < numFeatures) { + rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } + i += 1 + } + val coefficients = Vectors.dense(rawCoefficients) val intercept = parameters(1) val scale = math.exp(parameters(0)) val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, 
scale) @@ -434,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] * @param parameters including three part: The log of scale parameter, the intercept and * regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. + * @param featuresStd The standard deviation values of the features. */ -private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) - extends Serializable { +private class AFTAggregator( + parameters: BDV[Double], + fitIntercept: Boolean, + featuresStd: Array[Double]) extends Serializable { // the regression coefficients to the covariates private val coefficients = parameters.slice(2, parameters.length) - private val intercept = parameters.valueAt(1) + private val intercept = parameters(1) // sigma is the scale parameter of the AFT model private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 - private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length) - private var gradientInterceptSum = 0.0 - private var gradientLogSigmaSum = 0.0 + // Here we optimize loss function over log(sigma), intercept and coefficients + private val gradientSumArray = Array.ofDim[Double](parameters.length) def count: Long = totalCnt + def loss: Double = { + require(totalCnt > 0.0, s"The number of instances should be " + + s"greater than 0.0, but got $totalCnt.") + lossSum / totalCnt + } + def gradient: BDV[Double] = { + require(totalCnt > 0.0, s"The number of instances should be " + + s"greater than 0.0, but got $totalCnt.") + new BDV(gradientSumArray.map(_ / totalCnt.toDouble)) + } - def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt - - // Here we optimize loss function over coefficients, intercept and log(sigma) - def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), - BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble) /** * Add a new training data to this AFTAggregator, and update the loss and gradient @@ -466,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * @return This AFTAggregator object. */ def add(data: AFTPoint): this.type = { - - val interceptFlag = if (fitIntercept) 1.0 else 0.0 - - val xi = data.features.toBreeze + val xi = data.features val ti = data.label val delta = data.censor - val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma - lossSum += math.log(sigma) * delta - lossSum += (math.exp(epsilon) - delta * epsilon) + val margin = { + var sum = 0.0 + xi.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + sum += coefficients(index) * (value / featuresStd(index)) + } + } + sum + intercept + } + val epsilon = (math.log(ti) - margin) / sigma + + lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon) - // Sanity check (should never occur): - assert(!lossSum.isInfinity, - s"AFTAggregator loss sum is infinity. 
Error for unknown reason.") + val multiplier = (delta - math.exp(epsilon)) / sigma - val deltaMinusExpEps = delta - math.exp(epsilon) - gradientCoefficientSum += xi * deltaMinusExpEps / sigma - gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma - gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon + gradientSumArray(0) += delta + multiplier * sigma * epsilon + gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 } + xi.foreachActive { (index, value) => + if (featuresStd(index) != 0.0 && value != 0.0) { + gradientSumArray(index + 2) += multiplier * (value / featuresStd(index)) + } + } totalCnt += 1 this @@ -503,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) totalCnt += other.totalCnt lossSum += other.lossSum - gradientCoefficientSum += other.gradientCoefficientSum - gradientInterceptSum += other.gradientInterceptSum - gradientLogSigmaSum += other.gradientLogSigmaSum + var i = 0 + val len = this.gradientSumArray.length + while (i < len) { + this.gradientSumArray(i) += other.gradientSumArray(i) + i += 1 + } } this } @@ -516,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * It returns the loss and gradient at a particular point (parameters). * It's used in Breeze's convex optimization routines. */ -private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean) - extends DiffFunction[BDV[Double]] { +private class AFTCostFun( + data: RDD[AFTPoint], + fitIntercept: Boolean, + featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { - val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))( + val aftAggregator = data.treeAggregate( + new AFTAggregator(parameters, fitIntercept, featuresStd))( seqOp = (c, v) => (c, v) match { case (aggregator, instance) => aggregator.add(instance) }, diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index f4844cc671..76891ad562 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ + @transient var datasetUnivariateScaled: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite datasetMultivariate = sqlContext.createDataFrame( sc.parallelize(generateAFTInput( 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + datasetUnivariateScaled = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => + AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) + }) } /** @@ -356,6 +362,22 @@ class AFTSurvivalRegressionSuite } } + test("numerical stability of standardization") { + val trainer = new AFTSurvivalRegression() + val model1 = trainer.fit(datasetUnivariate) + val model2 = trainer.fit(datasetUnivariateScaled) + + /** + * During training we standardize the dataset first, so no matter how we multiple + * a scaling factor into the dataset, the convergence rate should be the same, + * and the coefficients should equal to the original coefficients multiple by + * the scaling factor. 
It will have no effect on the intercept and scale. + */ + assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01) + assert(model1.intercept ~== model2.intercept absTol 0.01) + assert(model1.scale ~== model2.scale absTol 0.01) + } + test("read/write") { def checkModelData( model: AFTSurvivalRegressionModel, -- cgit v1.2.3 From 7f024c47441a2f84fcc34a6021b976f036ea24c4 Mon Sep 17 00:00:00 2001 From: Kai Jiang Date: Tue, 12 Apr 2016 11:29:12 -0700 Subject: [SPARK-13597][PYSPARK][ML] Python API for GeneralizedLinearRegression ## What changes were proposed in this pull request? Python API for GeneralizedLinearRegression JIRA: https://issues.apache.org/jira/browse/SPARK-13597 ## How was this patch tested? The patch is tested with Python doctest. Author: Kai Jiang Closes #11468 from vectorijk/spark-13597. --- python/pyspark/ml/regression.py | 145 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 1c18df3b27..bc88f88b7f 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -28,6 +28,7 @@ from pyspark.sql import DataFrame __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', 'GBTRegressionModel', + 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel' 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', @@ -1197,6 +1198,150 @@ class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): return self._call_java("predict", features) +@inherit_doc +class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, HasPredictionCol, + HasFitIntercept, HasMaxIter, HasTol, HasRegParam, HasWeightCol, + HasSolver, JavaMLWritable, JavaMLReadable): + """ + Generalized Linear Regression. + + Fit a Generalized Linear Model specified by giving a symbolic description of the linear + predictor (link function) and a description of the error distribution (family). It supports + "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family + is listed below. The first link function of each family is the default one. + - "gaussian" -> "identity", "log", "inverse" + - "binomial" -> "logit", "probit", "cloglog" + - "poisson" -> "log", "identity", "sqrt" + - "gamma" -> "inverse", "identity", "log" + + .. seealso:: `GLM `_ + + >>> from pyspark.mllib.linalg import Vectors + >>> df = sqlContext.createDataFrame([ + ... (1.0, Vectors.dense(0.0, 0.0)), + ... (1.0, Vectors.dense(1.0, 2.0)), + ... (2.0, Vectors.dense(0.0, 0.0)), + ... (2.0, Vectors.dense(1.0, 1.0)),], ["label", "features"]) + >>> glr = GeneralizedLinearRegression(family="gaussian", link="identity") + >>> model = glr.fit(df) + >>> abs(model.transform(df).head().prediction - 1.5) < 0.001 + True + >>> model.coefficients + DenseVector([1.5..., -1.0...]) + >>> abs(model.intercept - 1.5) < 0.001 + True + >>> glr_path = temp_path + "/glr" + >>> glr.save(glr_path) + >>> glr2 = GeneralizedLinearRegression.load(glr_path) + >>> glr.getFamily() == glr2.getFamily() + True + >>> model_path = temp_path + "/glr_model" + >>> model.save(model_path) + >>> model2 = GeneralizedLinearRegressionModel.load(model_path) + >>> model.intercept == model2.intercept + True + >>> model.coefficients[0] == model2.coefficients[0] + True + + .. 
versionadded:: 2.0.0 + """ + + family = Param(Params._dummy(), "family", "The name of family which is a description of " + + "the error distribution to be used in the model. Supported options: " + + "gaussian(default), binomial, poisson and gamma.") + link = Param(Params._dummy(), "link", "The name of link function which provides the " + + "relationship between the linear predictor and the mean of the distribution " + + "function. Supported options: identity, log, inverse, logit, probit, cloglog " + + "and sqrt.") + + @keyword_only + def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, + regParam=0.0, weightCol=None, solver="irls"): + """ + __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ + regParam=0.0, weightCol=None, solver="irls") + """ + super(GeneralizedLinearRegression, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + @since("2.0.0") + def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, + regParam=0.0, weightCol=None, solver="irls"): + """ + setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ + family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ + regParam=0.0, weightCol=None, solver="irls") + Sets params for generalized linear regression. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def _create_model(self, java_model): + return GeneralizedLinearRegressionModel(java_model) + + @since("2.0.0") + def setFamily(self, value): + """ + Sets the value of :py:attr:`family`. + """ + self._paramMap[self.family] = value + return self + + @since("2.0.0") + def getFamily(self): + """ + Gets the value of family or its default value. + """ + return self.getOrDefault(self.family) + + @since("2.0.0") + def setLink(self, value): + """ + Sets the value of :py:attr:`link`. + """ + self._paramMap[self.link] = value + return self + + @since("2.0.0") + def getLink(self): + """ + Gets the value of link or its default value. + """ + return self.getOrDefault(self.link) + + +class GeneralizedLinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): + """ + Model fitted by GeneralizedLinearRegression. + + .. versionadded:: 2.0.0 + """ + + @property + @since("2.0.0") + def coefficients(self): + """ + Model coefficients. + """ + return self._call_java("coefficients") + + @property + @since("2.0.0") + def intercept(self): + """ + Model intercept. + """ + return self._call_java("intercept") + + if __name__ == "__main__": import doctest import pyspark.ml.regression -- cgit v1.2.3 From 1995c2e6482bf4af5a4be087bfc156311c1bec19 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 12 Apr 2016 11:30:09 -0700 Subject: [SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer ## What changes were proposed in this pull request? Use a random table name instead of `__THIS__` in SQLTransformer, and add a test for `transformSchema`. 
The problems of using `__THIS__` are: * It doesn't work under HiveContext (in Spark 1.6) * Race conditions ## How was this patch tested? * Manual test with HiveContext. * Added a unit test for `transformSchema` to improve coverage. cc: yhuai Author: Xiangrui Meng Closes #12330 from mengxr/SPARK-14563. --- .../scala/org/apache/spark/ml/feature/SQLTransformer.scala | 10 ++++++---- .../org/apache/spark/ml/feature/SQLTransformerSuite.scala | 10 ++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 95fe942c6b..2002d15745 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - val outputDF = dataset.sqlContext.sql(realStatement) - outputDF + dataset.sqlContext.sql(realStatement) } @Since("1.6.0") @@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) - dummyDF.registerTempTable(tableIdentifier) - val outputSchema = sqlContext.sql($(statement)).schema + val tableName = Identifiable.randomUID(uid) + val realStatement = $(statement).replace(tableIdentifier, tableName) + dummyDF.registerTempTable(tableName) + val outputSchema = sqlContext.sql(realStatement).schema + sqlContext.dropTempTable(tableName) outputSchema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 553e0b8702..e213e17d0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -49,4 +50,13 @@ class SQLTransformerSuite .setStatement("select * from __THIS__") testDefaultReadWrite(t) } + + test("transformSchema") { + val df = sqlContext.range(10) + val outputSchema = new SQLTransformer() + .setStatement("SELECT id + 1 AS id1 FROM __THIS__") + .transformSchema(df.schema) + val expected = StructType(Seq(StructField("id1", LongType, nullable = false))) + assert(outputSchema === expected) + } } -- cgit v1.2.3 From 111a62474a2fb7f4e7f19fcfb8efaae37aa40400 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 12 Apr 2016 11:34:40 -0700 Subject: [SPARK-14147][ML][SPARKR] SparkR predict should not output feature column ## What changes were proposed in this pull request? SparkR does not support type of vector which is the default type of feature column in ML. R predict also does not output intermediate feature column. So SparkR ```predict``` should not output feature column. 
In this PR, I only fix this issue for ```naiveBayes``` and ```survreg```. ```kmeans``` has the right code route already and ```glm``` will be fixed at SparkRWrapper refactor(#12294). ## How was this patch tested? No new tests. cc mengxr shivaram Author: Yanbo Liang Closes #11958 from yanboliang/spark-14147. --- .../scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index 2ae411555f..7835468626 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -44,7 +44,7 @@ private[r] class AFTSurvivalRegressionWrapper private ( } def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset) + pipeline.transform(dataset).drop(aftModel.getFeaturesCol) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala index 2cd709d2ee..b17207e99b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala @@ -37,7 +37,9 @@ private[r] class NaiveBayesWrapper private ( lazy val tables: Array[Double] = naiveBayesModel.theta.toArray.map(math.exp) def transform(dataset: Dataset[_]): DataFrame = { - pipeline.transform(dataset).drop(PREDICTED_LABEL_INDEX_COL) + pipeline.transform(dataset) + .drop(PREDICTED_LABEL_INDEX_COL) + .drop(naiveBayesModel.getFeaturesCol) } } -- cgit v1.2.3 From 852bbc6c0046d194fef0b6d0b99162ea2cc10286 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Tue, 12 Apr 2016 11:50:51 -0700 Subject: [SPARK-14556][SQL] Code clean-ups for package o.a.s.sql.execution.streaming.state ## What changes were proposed in this pull request? - `StateStoreConf.**max**DeltasForSnapshot` was renamed to `StateStoreConf.**min**DeltasForSnapshot` - some state switch checks were added - improved consistency between method names and string literals - other comments & typo fix ## How was this patch tested? N/A Author: Liwei Lin Closes #12323 from lw-lin/streaming-state-clean-up. --- .../state/HDFSBackedStateStoreProvider.scala | 40 ++++++++++++---------- .../sql/execution/streaming/state/StateStore.scala | 7 ++-- .../execution/streaming/state/StateStoreConf.scala | 3 +- .../streaming/state/StateStoreCoordinator.scala | 7 ++-- .../execution/streaming/state/StateStoreRDD.scala | 6 ++-- 5 files changed, 31 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 1e0a4a5d4f..3335755fd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -46,12 +46,14 @@ import org.apache.spark.util.Utils * Usage: * To update the data in the state store, the following order of operations are needed. * - * - val store = StateStore.get(operatorId, partitionId, version) // to get the right store - * - store.update(...) 
+ * // get the right store + * - val store = StateStore.get( + * StateStoreId(checkpointLocation, operatorId, partitionId), ..., version, ...) + * - store.put(...) * - store.remove(...) - * - store.commit() // commits all the updates to made with version number + * - store.commit() // commits all the updates to made; the new version will be returned * - store.iterator() // key-value data after last commit as an iterator - * - store.updates() // updates made in the last as an iterator + * - store.updates() // updates made in the last commit as an iterator * * Fault-tolerance model: * - Every set of updates is written to a delta file before committing. @@ -99,7 +101,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def put(key: UnsafeRow, value: UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or cancelled") + verify(state == UPDATING, "Cannot remove after already committed or aborted") val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) @@ -109,7 +111,7 @@ private[state] class HDFSBackedStateStoreProvider( // Value did not exist in previous version and was added already, keep it marked as added allUpdates.put(key, ValueAdded(key, value)) case Some(ValueUpdated(_, _)) | Some(KeyRemoved(_)) => - // Value existed in prev version and updated/removed, mark it as updated + // Value existed in previous version and updated/removed, mark it as updated allUpdates.put(key, ValueUpdated(key, value)) case None => // There was no prior update, so mark this as added or updated according to its presence @@ -122,7 +124,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Remove keys that match the following condition */ override def remove(condition: UnsafeRow => Boolean): Unit = { - verify(state == UPDATING, "Cannot remove after already committed or cancelled") + verify(state == UPDATING, "Cannot remove after already committed or aborted") val keyIter = mapToUpdate.keySet().iterator() while (keyIter.hasNext) { val key = keyIter.next @@ -146,7 +148,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { - verify(state == UPDATING, "Cannot commit after already committed or cancelled") + verify(state == UPDATING, "Cannot commit after already committed or aborted") try { finalizeDeltaFile(tempDeltaFileStream) @@ -161,8 +163,10 @@ private[state] class HDFSBackedStateStoreProvider( } } - /** Cancel all the updates made on this store. This store will not be usable any more. */ + /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { + verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") + state = ABORTED if (tempDeltaFileStream != null) { tempDeltaFileStream.close() @@ -170,7 +174,7 @@ private[state] class HDFSBackedStateStoreProvider( if (tempDeltaFile != null && fs.exists(tempDeltaFile)) { fs.delete(tempDeltaFile, true) } - logInfo("Canceled ") + logInfo("Aborted") } /** @@ -178,7 +182,8 @@ private[state] class HDFSBackedStateStoreProvider( * This can be called only after committing all the updates made in the current thread. 
*/ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - verify(state == COMMITTED, "Cannot get iterator of store data before committing") + verify(state == COMMITTED, + "Cannot get iterator of store data before committing or after aborting") HDFSBackedStateStoreProvider.this.iterator(newVersion) } @@ -187,7 +192,8 @@ private[state] class HDFSBackedStateStoreProvider( * This can be called only after committing all the updates made in the current thread. */ override def updates(): Iterator[StoreUpdate] = { - verify(state == COMMITTED, "Cannot get iterator of updates before committing") + verify(state == COMMITTED, + "Cannot get iterator of updates before committing or after aborting") allUpdates.values().asScala.toIterator } @@ -223,7 +229,7 @@ private[state] class HDFSBackedStateStoreProvider( } override def toString(): String = { - s"StateStore[id = (op=${id.operatorId},part=${id.partitionId}), dir = $baseDir]" + s"StateStore[id = (op=${id.operatorId}, part=${id.partitionId}), dir = $baseDir]" } /* Internal classes and methods */ @@ -277,7 +283,7 @@ private[state] class HDFSBackedStateStoreProvider( } else { if (!fs.isDirectory(baseDir)) { throw new IllegalStateException( - s"Cannot use ${id.checkpointLocation} for storing state data for $this as" + + s"Cannot use ${id.checkpointLocation} for storing state data for $this as " + s"$baseDir already exists and is not a directory") } } @@ -453,11 +459,11 @@ private[state] class HDFSBackedStateStoreProvider( filesForVersion(files, lastVersion).filter(_.isSnapshot == false) synchronized { loadedMaps.get(lastVersion) } match { case Some(map) => - if (deltaFilesForLastVersion.size > storeConf.maxDeltasForSnapshot) { + if (deltaFilesForLastVersion.size > storeConf.minDeltasForSnapshot) { writeSnapshotFile(lastVersion, map) } case None => - // The last map is not loaded, probably some other instance is incharge + // The last map is not loaded, probably some other instance is in charge } } @@ -506,7 +512,6 @@ private[state] class HDFSBackedStateStoreProvider( .lastOption val deltaBatchFiles = latestSnapshotFileBeforeVersion match { case Some(snapshotFile) => - val deltaBatchIds = (snapshotFile.version + 1) to version val deltaFiles = allFiles.filter { file => file.version > snapshotFile.version && file.version <= version @@ -579,4 +584,3 @@ private[state] class HDFSBackedStateStoreProvider( } } } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 07f63f928b..cc5327e0e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.streaming.state -import java.util.Timer import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable @@ -63,7 +62,7 @@ trait StateStore { */ def commit(): Long - /** Cancel all the updates that have been made to the store. */ + /** Abort all the updates that have been made to the store. */ def abort(): Unit /** @@ -109,8 +108,8 @@ case class KeyRemoved(key: UnsafeRow) extends StoreUpdate /** * Companion object to [[StateStore]] that provides helper methods to create and retrieve stores * by their unique ids. In addition, when a SparkContext is active (i.e. 
SparkEnv.get is not null), - * it also runs a periodic background tasks to do maintenance on the loaded stores. For each - * store, tt uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of + * it also runs a periodic background task to do maintenance on the loaded stores. For each + * store, it uses the [[StateStoreCoordinator]] to ensure whether the current loaded instance of * the store is the active instance. Accordingly, it either keeps it loaded and performs * maintenance, or unloads the store. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index f0f1f3a1a8..e55f63a6c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -26,7 +26,7 @@ private[streaming] class StateStoreConf(@transient private val conf: SQLConf) ex import SQLConf._ - val maxDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) + val minDeltasForSnapshot = conf.getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT) val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) } @@ -34,4 +34,3 @@ private[streaming] class StateStoreConf(@transient private val conf: SQLConf) ex private[streaming] object StateStoreConf { val empty = new StateStoreConf() } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala index 812e1b0a39..e418217238 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinator.scala @@ -50,8 +50,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging { private val endpointName = "StateStoreCoordinator" /** - * Create a reference to a [[StateStoreCoordinator]], This can be called from driver as well as - * executors. + * Create a reference to a [[StateStoreCoordinator]] */ def forDriver(env: SparkEnv): StateStoreCoordinatorRef = synchronized { try { @@ -75,7 +74,7 @@ private[sql] object StateStoreCoordinatorRef extends Logging { } /** - * Reference to a [[StateStoreCoordinator]] that can be used to coordinator instances of + * Reference to a [[StateStoreCoordinator]] that can be used to coordinate instances of * [[StateStore]]s across all the executors, and get their locations for job scheduling. 
*/ private[sql] class StateStoreCoordinatorRef private(rpcEndpointRef: RpcEndpointRef) { @@ -142,5 +141,3 @@ private class StateStoreCoordinator(override val rpcEnv: RpcEnv) extends ThreadS context.reply(true) } } - - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index df3d82c113..d708486d8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -22,12 +22,12 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{SerializableConfiguration, Utils} +import org.apache.spark.util.SerializableConfiguration /** * An RDD that allows computations to be executed against [[StateStore]]s. It - * uses the [[StateStoreCoordinator]] to use the locations of loaded state stores as - * preferred locations. + * uses the [[StateStoreCoordinator]] to get the locations of loaded state stores + * and use that as the preferred locations. */ class StateStoreRDD[T: ClassTag, U: ClassTag]( dataRDD: RDD[T], -- cgit v1.2.3 From 85e68b4bea3e4ad2e4063334dbf5b11af197d2ce Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 12 Apr 2016 12:29:54 -0700 Subject: [SPARK-14562] [SQL] improve constraints propagation in Union ## What changes were proposed in this pull request? Currently, Union only takes intersect of the constraints from it's children, all others are dropped, we should try to merge them together. This PR try to merge the constraints that have the same reference but came from different children, for example: `a > 10` and `a < 100` could be merged as `a > 10 || a < 100`. ## How was this patch tested? Added more cases in existing test. Author: Davies Liu Closes #12328 from davies/union_const. 
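To make the loosening rule concrete, here is a small standalone Scala sketch of the merge described above. The `Constraint`, `refOf` and `merge` names are invented for this illustration and this is not the Catalyst implementation, which operates on `ExpressionSet`s as the diff below shows.

```scala
// Standalone illustration only: helper names are invented, not Catalyst code.
object UnionConstraintSketch {
  type Constraint = String

  // Merge the constraints of two Union children: keep the common ones, and for
  // constraints that reference a single attribute on both sides, loosen them to
  // (A1 && A2 ...) || (B1 && B2 ...) per attribute.
  def merge(a: Set[Constraint], b: Set[Constraint], refOf: Constraint => String): Set[Constraint] = {
    val common = a.intersect(b)
    val restA = (a -- common).groupBy(refOf)
    val restB = (b -- common).groupBy(refOf)
    val loosened = (restA.keySet intersect restB.keySet).map { attr =>
      s"(${restA(attr).mkString(" && ")}) || (${restB(attr).mkString(" && ")})"
    }
    common ++ loosened
  }

  def main(args: Array[String]): Unit = {
    val child1 = Set("a > 10", "b < 10")   // constraints from the first child
    val child2 = Set("a > 11", "b < 11")   // constraints from the second child
    merge(child1, child2, _.split(" ").head).foreach(println)
  }
}
```

Running the sketch prints constraints such as `(a > 10) || (a > 11)` and `(b < 10) || (b < 11)`, which mirrors the expectations added to `ConstraintPropagationSuite` below.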
--- .../sql/catalyst/plans/logical/basicOperators.scala | 16 +++++++++++++++- .../sql/catalyst/plans/ConstraintPropagationSuite.scala | 14 ++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d3353beb09..d4fc9e4da9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { }) } + private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = { + val common = a.intersect(b) + // The constraint with only one reference could be easily inferred as predicate + // Grouping the constraints by it's references so we can combine the constraints with same + // reference together + val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + // loose the constraints by: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2) + val others = (othera.keySet intersect otherb.keySet).map { attr => + Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And)) + } + common ++ others + } + override protected def validConstraints: Set[Expression] = { children .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) - .reduce(_ intersect _) + .reduce(merge(_, _)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 49c1353efb..81cc6b123c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite { .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) + + val a = resolveColumn(tr1, "a") + verifyConstraints(tr1 + .where('a.attr > 10) + .union(tr2.where('d.attr > 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a)))) + + val b = resolveColumn(tr1, "b") + verifyConstraints(tr1 + .where('a.attr > 10 && 'b.attr < 10) + .union(tr2.where('d.attr > 11 && 'e.attr < 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b)))) } test("propagating constraints in intersect") { -- cgit v1.2.3 From bcd2076274b1a95f74616d0ceacb0696e38b5f4c Mon Sep 17 00:00:00 2001 From: bomeng Date: Tue, 12 Apr 2016 13:43:39 -0700 Subject: [SPARK-14414][SQL] improve the error message class hierarchy ## What changes were proposed in this pull request? Before we are using `AnalysisException`, `ParseException`, `NoSuchFunctionException` etc when a parsing error encounters. I am trying to make it consistent and also **minimum** code impact to the current implementation by changing the class hierarchy. 1. `NoSuchItemException` is removed, since it is an abstract class and it just simply takes a message string. 2. 
`NoSuchDatabaseException`, `NoSuchTableException`, `NoSuchPartitionException` and `NoSuchFunctionException` now extends `AnalysisException`, as well as `ParseException`, they are all under `AnalysisException` umbrella, but you can also determine how to use them in a granular way. ## How was this patch tested? The existing test cases should cover this patch. Author: bomeng Closes #12314 from bomeng/SPARK-14414. --- .../catalyst/analysis/NoSuchItemException.scala | 31 ++++++---------------- .../apache/spark/sql/execution/command/ddl.scala | 1 + .../spark/sql/hive/HiveExternalCatalog.scala | 3 --- .../spark/sql/hive/execution/HiveQuerySuite.scala | 1 - 4 files changed, 9 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 96fd1a027e..5e18316c94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec @@ -24,29 +25,13 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. */ -abstract class NoSuchItemException extends Exception { - override def getMessage: String -} +class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database $db not found") -class NoSuchDatabaseException(db: String) extends NoSuchItemException { - override def getMessage: String = s"Database $db not found" -} +class NoSuchTableException(db: String, table: String) + extends AnalysisException(s"Table or View $table not found in database $db") -class NoSuchTableException(db: String, table: String) extends NoSuchItemException { - override def getMessage: String = s"Table or View $table not found in database $db" -} +class NoSuchPartitionException(db: String, table: String, spec: TablePartitionSpec) extends + AnalysisException(s"Partition not found in table $table database $db:\n" + spec.mkString("\n")) -class NoSuchPartitionException( - db: String, - table: String, - spec: TablePartitionSpec) - extends NoSuchItemException { - - override def getMessage: String = { - s"Partition not found in table $table database $db:\n" + spec.mkString("\n") - } -} - -class NoSuchFunctionException(db: String, func: String) extends NoSuchItemException { - override def getMessage: String = s"Function $func not found in database $db" -} +class NoSuchFunctionException(db: String, func: String) + extends AnalysisException(s"Function $func not found in database $db") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 758a7e45d2..5137bd11d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ + // Note: The definition of these commands are based on the ones 
described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 482f47428d..f627384253 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -25,7 +25,6 @@ import org.apache.thrift.TException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchItemException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.hive.client.HiveClient @@ -66,8 +65,6 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat try { body } catch { - case e: NoSuchItemException => - throw new AnalysisException(e.getMessage) case NonFatal(e) if isClientException(e) => throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 0c57ede9ed..af73baa1f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -28,7 +28,6 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkException, SparkFiles} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin -- cgit v1.2.3 From 3e53de4bdd6d7b6de1fe3e5bfbdc53180aa9a737 Mon Sep 17 00:00:00 2001 From: Terence Yim Date: Tue, 12 Apr 2016 13:46:39 -0700 Subject: [SPARK-14513][CORE] Fix threads left behind after stopping SparkContext ## What changes were proposed in this pull request? Shutting down `QueuedThreadPool` used by Jetty `Server` to avoid threads leakage after SparkContext is stopped. Note: If this fix is going to apply to the `branch-1.6`, one more patch on the `NettyRpcEnv` class is needed so that the `NettyRpcEnv._fileServer.shutdown` is called in the `NettyRpcEnv.cleanup` method. This is due to the removal of `_fileServer` field in the `NettyRpcEnv` class in the master branch. Please advice if a second PR is necessary for bring this fix back to `branch-1.6` ## How was this patch tested? Ran the ./dev/run-tests locally Author: Terence Yim Closes #12318 from chtyim/fixes/SPARK-14513-thread-leak. 
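The pattern this patch applies in `HttpServer`, `JettyUtils` and (via `ServerInfo.stop`) `WebUI` can be summarized by the small helper below. It is only a sketch of that shutdown pattern, reusing the Jetty 8 style APIs already present in the diff (`Server#getThreadPool`, `LifeCycle#stop`); it is not a new Spark API.

```scala
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.util.component.LifeCycle

// Sketch: stopping a Jetty Server does not stop the QueuedThreadPool that was
// handed to it, so shut the pool down explicitly when it supports LifeCycle.
object JettyShutdownSketch {
  def stopServerAndPool(server: Server): Unit = {
    server.stop()
    server.getThreadPool match {
      case lifeCycle: LifeCycle => lifeCycle.stop()
      case _ => // pool is not managed through LifeCycle; nothing more to do
    }
  }
}
```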
--- core/src/main/scala/org/apache/spark/HttpServer.scala | 7 +++++++ core/src/main/scala/org/apache/spark/ui/JettyUtils.scala | 14 +++++++++++++- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 2 +- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index 9fad1f6786..982b6d6b61 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -25,6 +25,7 @@ import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} +import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.util.thread.QueuedThreadPool @@ -155,6 +156,12 @@ private[spark] class HttpServer( throw new ServerStateException("Server is already stopped") } else { server.stop() + // Stop the ThreadPool if it supports stop() method (through LifeCycle). + // It is needed because stopping the Server won't stop the ThreadPool it uses. + val threadPool = server.getThreadPool + if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) { + threadPool.asInstanceOf[LifeCycle].stop + } port = -1 server = null } diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index c3c59f857d..119165f724 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -30,6 +30,7 @@ import org.eclipse.jetty.server.handler._ import org.eclipse.jetty.server.nio.SelectChannelConnector import org.eclipse.jetty.server.ssl.SslSelectChannelConnector import org.eclipse.jetty.servlet._ +import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.QueuedThreadPool import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} @@ -350,4 +351,15 @@ private[spark] object JettyUtils extends Logging { private[spark] case class ServerInfo( server: Server, boundPort: Int, - rootHandler: ContextHandlerCollection) + rootHandler: ContextHandlerCollection) { + + def stop(): Unit = { + server.stop() + // Stop the ThreadPool if it supports stop() method (through LifeCycle). + // It is needed because stopping the Server won't stop the ThreadPool it uses. + val threadPool = server.getThreadPool + if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) { + threadPool.asInstanceOf[LifeCycle].stop + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 250b7f2e5f..3939b111b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -153,7 +153,7 @@ private[spark] abstract class WebUI( def stop() { assert(serverInfo.isDefined, "Attempted to stop %s before binding to a server!".format(className)) - serverInfo.get.server.stop() + serverInfo.get.stop() } } -- cgit v1.2.3 From 1ef5f8cfa6d6b7c9ec58a96dc447ab56ef709381 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 12 Apr 2016 15:03:00 -0700 Subject: [SPARK-14544] [SQL] improve performance of SQL UI tab ## What changes were proposed in this pull request? 
This PR improve the performance of SQL UI by: 1) remove the details column in all executions page (the first page in SQL tab). We can check the details by enter the execution page. 2) break-all is super slow in Chrome recently, so switch to break-word. 3) Using "display: none" to hide a block. 4) using one js closure for for all the executions, not one for each. 5) remove the height limitation of details, don't need to scroll it in the tiny window. ## How was this patch tested? Exists tests. ![ui](https://cloud.githubusercontent.com/assets/40902/14445712/68d7b258-0004-11e6-9b48-5d329b05d165.png) Author: Davies Liu Closes #12311 from davies/ui_perf. --- .../resources/org/apache/spark/ui/static/webui.css | 8 +++-- .../spark/sql/execution/ui/AllExecutionsPage.scala | 40 ++++++---------------- .../apache/spark/streaming/UISeleniumSuite.scala | 4 +-- 3 files changed, 17 insertions(+), 35 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 48f86d1536..47dd9162a1 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -106,21 +106,22 @@ pre { line-height: 18px; padding: 6px; margin: 0; + word-break: break-word; border-radius: 3px; } .stage-details { - max-height: 100px; overflow-y: auto; margin: 0; + display: block; transition: max-height 0.25s ease-out, padding 0.25s ease-out; } .stage-details.collapsed { - max-height: 0; padding-top: 0; padding-bottom: 0; border: none; + display: none; } .description-input { @@ -143,14 +144,15 @@ pre { max-height: 300px; overflow-y: auto; margin: 0; + display: block; transition: max-height 0.25s ease-out, padding 0.25s ease-out; } .stacktrace-details.collapsed { - max-height: 0; padding-top: 0; padding-bottom: 0; border: none; + display: none; } span.expand-additional-metrics, span.expand-dag-viz { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index d3e823fdeb..e96fb9f755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -55,6 +55,12 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L } _content } + content ++= + UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) } } @@ -118,14 +124,12 @@ private[ui] abstract class ExecutionTable( {failedJobs} }} - {detailCell(executionUIData.physicalPlanDescription)} } private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { val details = if (execution.details.nonEmpty) { - + +details ++ - } - def toNodeSeq: Seq[Node] = {

    {tableName}

    @@ -197,7 +177,7 @@ private[ui] class RunningExecutionTable( showFailedJobs = true) { override protected def header: Seq[String] = - baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs", "Detail") + baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs") } private[ui] class CompletedExecutionTable( @@ -215,7 +195,7 @@ private[ui] class CompletedExecutionTable( showSucceededJobs = true, showFailedJobs = false) { - override protected def header: Seq[String] = baseHeader ++ Seq("Jobs", "Detail") + override protected def header: Seq[String] = baseHeader ++ Seq("Jobs") } private[ui] class FailedExecutionTable( @@ -234,5 +214,5 @@ private[ui] class FailedExecutionTable( showFailedJobs = true) { override protected def header: Seq[String] = - baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs", "Detail") + baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs") } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala index 3f12de38ef..454c3dffa3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -169,9 +169,9 @@ class UISeleniumSuite List("4/4", "4/4", "4/4", "0/4 (1 failed)")) // Check stacktrace - val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.text).toSeq + val errorCells = findAll(cssSelector(""".stacktrace-details""")).map(_.underlying).toSeq errorCells should have size 1 - errorCells(0) should include("java.lang.RuntimeException: Oops") + // Can't get the inner (invisible) text without running JS // Check the job link in the batch page is right go to (jobLinks(0)) -- cgit v1.2.3 From c439d88e99c35a5f29f071715addfee8cbb215dc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 12 Apr 2016 15:28:08 -0700 Subject: [SPARK-14547] Avoid DNS resolution for reusing connections ## What changes were proposed in this pull request? This patch changes the connection creation logic in the network client module to avoid DNS resolution when reusing connections. ## How was this patch tested? Testing in production. This is too difficult to test in isolation (for high fidelity unit tests, we'd need to change the DNS resolution behavior in the JVM). Author: Reynold Xin Closes #12315 from rxin/SPARK-14547. --- .../network/client/TransportClientFactory.java | 31 ++++++++++++++-------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index b5a9d6671f..a27aaf2b27 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -123,16 +123,15 @@ public class TransportClientFactory implements Closeable { public TransportClient createClient(String remoteHost, int remotePort) throws IOException { // Get connection from the connection pool first. // If it is not found or not active, create a new one. 
- long preResolveHost = System.nanoTime(); - final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); - long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; - logger.info("Spent {} ms to resolve {}", hostResolveTimeMs, address); + // Use unresolved address here to avoid DNS resolution each time we creates a client. + final InetSocketAddress unresolvedAddress = + InetSocketAddress.createUnresolved(remoteHost, remotePort); // Create the ClientPool if we don't have it yet. - ClientPool clientPool = connectionPool.get(address); + ClientPool clientPool = connectionPool.get(unresolvedAddress); if (clientPool == null) { - connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer)); - clientPool = connectionPool.get(address); + connectionPool.putIfAbsent(unresolvedAddress, new ClientPool(numConnectionsPerPeer)); + clientPool = connectionPool.get(unresolvedAddress); } int clientIndex = rand.nextInt(numConnectionsPerPeer); @@ -149,25 +148,35 @@ public class TransportClientFactory implements Closeable { } if (cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); + logger.trace("Returning cached connection to {}: {}", + cachedClient.getSocketAddress(), cachedClient); return cachedClient; } } // If we reach here, we don't have an existing connection open. Let's create a new one. // Multiple threads might race here to create new connections. Keep only one of them active. + final long preResolveHost = System.nanoTime(); + final InetSocketAddress resolvedAddress = new InetSocketAddress(remoteHost, remotePort); + final long hostResolveTimeMs = (System.nanoTime() - preResolveHost) / 1000000; + if (hostResolveTimeMs > 2000) { + logger.warn("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); + } else { + logger.trace("DNS resolution for {} took {} ms", resolvedAddress, hostResolveTimeMs); + } + synchronized (clientPool.locks[clientIndex]) { cachedClient = clientPool.clients[clientIndex]; if (cachedClient != null) { if (cachedClient.isActive()) { - logger.trace("Returning cached connection to {}: {}", address, cachedClient); + logger.trace("Returning cached connection to {}: {}", resolvedAddress, cachedClient); return cachedClient; } else { - logger.info("Found inactive connection to {}, creating a new one.", address); + logger.info("Found inactive connection to {}, creating a new one.", resolvedAddress); } } - clientPool.clients[clientIndex] = createClient(address); + clientPool.clients[clientIndex] = createClient(resolvedAddress); return clientPool.clients[clientIndex]; } } -- cgit v1.2.3 From d187e7dea9540d26b7800de4eb79863ef5f574bf Mon Sep 17 00:00:00 2001 From: Sital Kedia Date: Tue, 12 Apr 2016 16:10:07 -0700 Subject: [SPARK-14363] Fix executor OOM due to memory leak in the Sorter ## What changes were proposed in this pull request? Fix memory leak in the Sorter. When the UnsafeExternalSorter spills the data to disk, it does not free up the underlying pointer array. As a result, we see a lot of executor OOM and also memory under utilization. This is a regression partially introduced in PR https://github.com/apache/spark/pull/9241 ## How was this patch tested? Tested by running a job and observed around 30% speedup after this change. Author: Sital Kedia Closes #12285 from sitalkedia/executor_oom. 
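The essence of the fix, for both `ShuffleInMemorySorter` and `UnsafeInMemorySorter`, is that `reset()` now returns the (possibly grown) pointer array to the memory consumer and reallocates one of the initial size, and that the callers only do this after the spilled record pages have been freed. Below is a minimal sketch of that reset pattern, with an invented `ArrayAllocator` standing in for Spark's `MemoryConsumer`.

```scala
// Invented allocator interface for illustration; in Spark this role is played
// by MemoryConsumer.allocateArray / freeArray.
trait ArrayAllocator {
  def allocateArray(size: Long): Array[Long]
  def freeArray(array: Array[Long]): Unit
}

final class InMemorySorterSketch(allocator: ArrayAllocator, initialSize: Long) {
  private var array: Array[Long] = allocator.allocateArray(initialSize)
  private var pos: Int = 0

  def insert(recordPointer: Long): Unit = {
    array(pos) = recordPointer
    pos += 1
  }

  // Before the fix, reset() only did `pos = 0`, so an array that had grown while
  // buffering records stayed allocated across spills and leaked task memory.
  def reset(): Unit = {
    allocator.freeArray(array)
    array = allocator.allocateArray(initialSize)
    pos = 0
  }
}
```

Calling `reset()` only after the record pages have been released (as the patch does in the spill paths) matters because an over-allocated task might otherwise have no memory left for the fresh pointer array.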
--- .../java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java | 6 ++++-- .../java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java | 7 +++++++ .../spark/util/collection/unsafe/sort/UnsafeExternalSorter.java | 7 +++++-- .../spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java | 7 +++++++ 4 files changed, 23 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 81ee7ab58a..3c2980e442 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -215,8 +215,6 @@ final class ShuffleExternalSorter extends MemoryConsumer { } } - inMemSorter.reset(); - if (!isLastFile) { // i.e. this is a spill file // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter @@ -255,6 +253,10 @@ final class ShuffleExternalSorter extends MemoryConsumer { writeSortedFile(false); final long spillSize = freeMemory(); + inMemSorter.reset(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task is over allocated memory, then without freeing the memory pages, + // we might not be able to get memory for the pointer array. taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); return spillSize; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java index fe79ff0e30..76b0e6a304 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java @@ -51,9 +51,12 @@ final class ShuffleInMemorySorter { */ private int pos = 0; + private int initialSize; + ShuffleInMemorySorter(MemoryConsumer consumer, int initialSize) { this.consumer = consumer; assert (initialSize > 0); + this.initialSize = initialSize; this.array = consumer.allocateArray(initialSize); this.sorter = new Sorter<>(ShuffleSortDataFormat.INSTANCE); } @@ -70,6 +73,10 @@ final class ShuffleInMemorySorter { } public void reset() { + if (consumer != null) { + consumer.freeArray(array); + this.array = consumer.allocateArray(initialSize); + } pos = 0; } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index ded8f0472b..ef79b49083 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -200,14 +200,17 @@ public final class UnsafeExternalSorter extends MemoryConsumer { spillWriter.write(baseObject, baseOffset, recordLength, sortedRecords.getKeyPrefix()); } spillWriter.close(); - - inMemSorter.reset(); } final long spillSize = freeMemory(); // Note that this is more-or-less going to be a multiple of the page size, so wasted space in // pages will currently be counted as memory spilled even though that space isn't actually // written to disk. This also counts the space needed to store the sorter's pointer array. 
+ inMemSorter.reset(); + // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the + // records. Otherwise, if the task has over-allocated memory, then without freeing the memory pages, + // we might not be able to get memory for the pointer array. + taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); return spillSize; diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 145c3a1950..01eae0e8dc 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -84,6 +84,8 @@ public final class UnsafeInMemorySorter { */ private int pos = 0; + private long initialSize; + public UnsafeInMemorySorter( final MemoryConsumer consumer, final TaskMemoryManager memoryManager, @@ -102,6 +104,7 @@ public final class UnsafeInMemorySorter { LongArray array) { this.consumer = consumer; this.memoryManager = memoryManager; + this.initialSize = array.size(); if (recordComparator != null) { this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); @@ -123,6 +126,10 @@ public final class UnsafeInMemorySorter { } public void reset() { + if (consumer != null) { + consumer.freeArray(array); + this.array = consumer.allocateArray(initialSize); + } pos = 0; } -- cgit v1.2.3 From 372baf0479840695388515170e6eae0b3fc4125e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 12 Apr 2016 17:26:37 -0700 Subject: [SPARK-14578] [SQL] Fix codegen for CreateExternalRow with nested wide schema ## What changes were proposed in this pull request? With a wide schema, the expressions for the fields are split into multiple functions, but the loopVar variable can't be accessed from those split functions, so this PR changes them into class members. ## How was this patch tested? Added regression test. Author: Davies Liu Closes #12338 from davies/nested_row.
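To make the scope problem concrete before the diff: once a wide projection's generated body is split across several methods, a local variable declared in the outer method is not visible from the split-out methods, so the loop variable has to become per-instance state (which is what the `ctx.addMutableState` calls below arrange). The class below is a hand-written analogue for illustration only, not the actual generated code.

```
// Hypothetical illustration of why the loop variable must be a field rather than a local.
class WideRowWriter(input: Array[Any]) {
  // Promoted to members so every split-out writer method can see them,
  // mirroring what addMutableState does for loopVar.isNull / loopVar.value.
  private var loopValue: Any = _
  private var loopIsNull: Boolean = _

  // Stands in for the many generated writeFields_N()-style methods.
  private def writeField(out: Array[Any], i: Int): Unit =
    out(i) = if (loopIsNull) null else loopValue

  def write(): Array[Any] = {
    val out = new Array[Any](input.length)
    var i = 0
    while (i < input.length) {
      loopValue = input(i)              // as locals, these would be out of scope
      loopIsNull = loopValue == null    // inside writeField once the body is split
      writeField(out, i)
      i += 1
    }
    out
  }
}
```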
--- .../apache/spark/sql/catalyst/expressions/objects.scala | 8 +++++--- .../spark/sql/execution/datasources/json/JsonSuite.scala | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 28b6b2adf8..26b1ff39b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -446,6 +446,8 @@ case class MapObjects private( override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) + ctx.addMutableState("boolean", loopVar.isNull, "") + ctx.addMutableState(elementJavaType, loopVar.value, "") val genInputData = inputData.gen(ctx) val genFunction = lambdaFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") @@ -466,9 +468,9 @@ case class MapObjects private( } val loopNullCheck = if (primitiveElement) { - s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" + s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" } else { - s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" + s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } s""" @@ -484,7 +486,7 @@ case class MapObjects private( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $elementJavaType ${loopVar.value} = + ${loopVar.value} = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 2a18acb95b..e17340c70b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1664,4 +1664,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("wide nested json table") { + val nested = (1 to 100).map { i => + s""" + |"c$i": $i + """.stripMargin + }.mkString(", ") + val json = s""" + |{"a": [{$nested}], "b": [{$nested}]} + """.stripMargin + val rdd = sqlContext.sparkContext.makeRDD(Seq(json)) + val df = sqlContext.read.json(rdd) + assert(df.schema.size === 2) + df.collect() + } } -- cgit v1.2.3 From 768b3d623c29eaf960be096845b7c421f8a3ba36 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Tue, 12 Apr 2016 17:31:47 -0700 Subject: [SPARK-14579][SQL] Fix a race condition in StreamExecution.processAllAvailable ## What changes were proposed in this pull request? There is a race condition in `StreamExecution.processAllAvailable`. Here is an execution order to reproduce it. | Time |Thread 1 | MicroBatchThread | |:-------------:|:-------------:|:-----:| | 1 | | `dataAvailable in constructNextBatch` returns false | | 2 | addData(newData) | | | 3 | `noNewData = false` in processAllAvailable | | | 4 | | noNewData = true | | 5 | `noNewData` is true so just return | | The root cause is that `checking dataAvailable and change noNewData to true` is not atomic. This PR puts these two actions into `synchronized` to make sure they are atomic. 
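A condensed, self-contained sketch of the resulting shape (all names here are simplified stand-ins for the real fields in `StreamExecution`; `pendingData` stands in for the unprocessed source offsets):

```
import scala.collection.mutable

class ToyStreamExecution {
  private val awaitBatchLock = new Object
  private val pendingData = mutable.Queue.empty[String]   // stand-in for unprocessed offsets
  @volatile private var noNewData = false

  def addData(d: String): Unit = awaitBatchLock.synchronized { pendingData.enqueue(d) }

  // MicroBatchThread: checking for data and setting noNewData is now one atomic step.
  def constructNextBatch(): Unit = {
    val hasNewData = awaitBatchLock.synchronized {
      if (pendingData.nonEmpty) true
      else { noNewData = true; false }
    }
    if (hasNewData) {
      val batch = awaitBatchLock.synchronized(pendingData.dequeue())
      // ... run `batch` ...
    } else {
      awaitBatchLock.synchronized(awaitBatchLock.notifyAll())  // wake up waiting callers
    }
  }

  // Caller thread: clears and re-checks the flag while holding the same lock, so the
  // interleaving in the table above can no longer lose the "no new data" signal.
  def processAllAvailable(): Unit = awaitBatchLock.synchronized {
    noNewData = false
    while (!noNewData) awaitBatchLock.wait(10000)
  }
}
```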
In addition, this PR also has the following changes: - Make `committedOffsets` and `availableOffsets` volatile to make sure they can be seen in other threads. - Copy the reference of `availableOffsets` to a local variable so that `sourceStatuses` can use a snapshot of `availableOffsets`. ## How was this patch tested? Existing unit tests. Author: Shixiong Zhu Closes #12339 from zsxwing/race-condition. --- .../sql/execution/streaming/StreamExecution.scala | 40 +++++++++++++++------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 688e051e1f..87dd27a2b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -59,12 +59,14 @@ class StreamExecution( * Tracks how much data we have processed and committed to the sink or state store from each * input source. */ + @volatile private[sql] var committedOffsets = new StreamProgress /** * Tracks the offsets that are available to be processed, but have not yet be committed to the * sink. */ + @volatile private var availableOffsets = new StreamProgress /** The current batchId or -1 if execution has not yet been initialized. */ @@ -111,7 +113,8 @@ class StreamExecution( /** Returns current status of all the sources. */ override def sourceStatuses: Array[SourceStatus] = { - sources.map(s => new SourceStatus(s.toString, availableOffsets.get(s))).toArray + val localAvailableOffsets = availableOffsets + sources.map(s => new SourceStatus(s.toString, localAvailableOffsets.get(s))).toArray } /** Returns current status of the sink. */ @@ -228,7 +231,7 @@ class StreamExecution( * Queries all of the sources to see if any new data is available. When there is new data the * batchId counter is incremented and a new log entry is written with the newest offsets. */ - private def constructNextBatch(): Boolean = { + private def constructNextBatch(): Unit = { // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). // If we interrupt some thread running Shell.runCommand, we may hit this issue. // As "FileStreamSource.getOffset" will create a file using HDFS API and call "Shell.runCommand" @@ -241,7 +244,15 @@ class StreamExecution( } availableOffsets ++= newData - if (dataAvailable) { + val hasNewData = awaitBatchLock.synchronized { + if (dataAvailable) { + true + } else { + noNewData = true + false + } + } + if (hasNewData) { // There is a potential dead-lock in Hadoop "Shell.runCommand" before 2.5.0 (HADOOP-10622). // If we interrupt some thread running Shell.runCommand, we may hit this issue. // As "offsetLog.add" will create a file using HDFS API and call "Shell.runCommand" to set @@ -254,15 +265,11 @@ class StreamExecution( } currentBatchId += 1 logInfo(s"Committed offsets for batch $currentBatchId.") - true } else { - noNewData = true awaitBatchLock.synchronized { // Wake up any threads that are waiting for the stream to progress. awaitBatchLock.notifyAll() } - - false } } @@ -353,7 +360,10 @@ class StreamExecution( * least the given `Offset`. This method is indented for use primarily when writing tests. 
*/ def awaitOffset(source: Source, newOffset: Offset): Unit = { - def notDone = !committedOffsets.contains(source) || committedOffsets(source) < newOffset + def notDone = { + val localCommittedOffsets = committedOffsets + !localCommittedOffsets.contains(source) || localCommittedOffsets(source) < newOffset + } while (notDone) { logInfo(s"Waiting until $newOffset at $source") @@ -365,13 +375,17 @@ class StreamExecution( /** A flag to indicate that a batch has completed with no new data available. */ @volatile private var noNewData = false - override def processAllAvailable(): Unit = { + override def processAllAvailable(): Unit = awaitBatchLock.synchronized { noNewData = false - while (!noNewData) { - awaitBatchLock.synchronized { awaitBatchLock.wait(10000) } - if (streamDeathCause != null) { throw streamDeathCause } + while (true) { + awaitBatchLock.wait(10000) + if (streamDeathCause != null) { + throw streamDeathCause + } + if (noNewData) { + return + } } - if (streamDeathCause != null) { throw streamDeathCause } } override def awaitTermination(): Unit = { -- cgit v1.2.3 From 587cd554af24601d332e9ce5c74e98b62d0fd830 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Wed, 13 Apr 2016 10:20:03 +0800 Subject: [MINOR][SQL] Remove some unused imports in datasources. ## What changes were proposed in this pull request? It looks several recent commits for datasources (maybe while removing old `HadoopFsRelation` interface) missed removing some unused imports. This PR removes some unused imports in datasources. ## How was this patch tested? `sbt scalastyle` and some unit tests for them. Author: hyukjinkwon Closes #12326 from HyukjinKwon/minor-imports. --- .../org/apache/spark/ml/source/libsvm/LibSVMRelation.scala | 3 --- .../spark/sql/execution/datasources/DataSourceStrategy.scala | 8 +------- .../execution/datasources/InsertIntoHadoopFsRelation.scala | 1 - .../spark/sql/execution/datasources/SqlNewHadoopRDD.scala | 2 -- .../spark/sql/execution/datasources/WriterContainer.scala | 2 +- .../spark/sql/execution/datasources/json/JSONRelation.scala | 4 +--- .../spark/sql/execution/datasources/json/JacksonParser.scala | 6 +++--- .../sql/execution/datasources/parquet/ParquetRelation.scala | 12 +++--------- .../spark/sql/execution/datasources/text/DefaultSource.scala | 7 +------ .../scala/org/apache/spark/sql/hive/orc/OrcRelation.scala | 2 -- 10 files changed, 10 insertions(+), 37 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 4737b6fe52..2f1f2523fd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -26,11 +26,9 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.annotation.Since -import org.apache.spark.broadcast.Broadcast import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder @@ -40,7 +38,6 @@ import org.apache.spark.sql.execution.datasources.{CaseInsensitiveMap, HadoopFil import org.apache.spark.sql.sources._ import 
org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet private[libsvm] class LibSVMOutputWriter( path: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index c3885a3be5..ac3c52e901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,10 +19,8 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer -import org.apache.spark.TaskContext -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} +import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala @@ -35,14 +33,10 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.DataSourceScan.{INPUT_PATHS, PUSHED_FILTERS} -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.ExecutedCommand -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVectorUtils} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.BitSet /** * Replaces generic operations with specific variants that are designed to work with Spark diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index e31380e17d..889c0204f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ -import org.apache.spark.util.Utils /** * A command for writing data to a [[HadoopFsRelation]]. Supports both overwriting and appending. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 6ddb218a22..4d6864d8ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -34,8 +34,6 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.datasources.parquet.VectorizedParquetRecordReader -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index b9a3162aba..815d1d01ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} +import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index f32fea4183..7364a1dc06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -28,18 +28,16 @@ import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet class DefaultSource extends FileFormat with DataSourceRegister { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 8bc53bae6c..aeee2600a1 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -54,9 +54,9 @@ object JacksonParser extends Logging { * with an array. */ def convertRootField( - factory: JsonFactory, - parser: JsonParser, - schema: DataType): Any = { + factory: JsonFactory, + parser: JsonParser, + schema: DataType): Any = { import com.fasterxml.jackson.core.JsonToken._ (parser.getCurrentToken, schema) match { case (START_ARRAY, st: StructType) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index dbda094996..b91e892f8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.net.URI -import java.util.{List => JList} import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ @@ -27,23 +26,19 @@ import scala.util.{Failure, Try} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} -import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.{Log => ApacheParquetLog} import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ -import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.schema.MessageType import org.slf4j.bridge.SLF4JBridgeHandler -import org.apache.spark.{Partition => SparkPartition, SparkException} -import org.apache.spark.broadcast.Broadcast +import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.JoinedRow @@ -53,8 +48,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{AtomicType, DataType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.BitSet +import org.apache.spark.util.SerializableConfiguration private[sql] class DefaultSource extends FileFormat diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 28b03ee7c3..94ecb7a286 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -19,14 +19,10 @@ package org.apache.spark.sql.execution.datasources.text import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import 
org.apache.hadoop.io.{LongWritable, NullWritable, Text} -import org.apache.hadoop.mapred.{JobConf, TextInputFormat} +import org.apache.hadoop.io.{NullWritable, Text} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -35,7 +31,6 @@ import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFile import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet /** * A data source for reading text files. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index e915f3dfe2..21591ec093 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -33,7 +33,6 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.{HadoopRDD, RDD} import org.apache.spark.sql.{Row, SQLContext} @@ -45,7 +44,6 @@ import org.apache.spark.sql.hive.{HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet private[sql] class DefaultSource extends FileFormat with DataSourceRegister with Serializable { -- cgit v1.2.3 From a5f8c9b15b6181f04c1314c638017adb8c88c7df Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 13 Apr 2016 11:41:09 +0800 Subject: [SPARK-14554][SQL][FOLLOW-UP] use checkDataset to check the result ## What changes were proposed in this pull request? address this comment: https://github.com/apache/spark/pull/12322#discussion_r59417359 ## How was this patch tested? N/A Author: Wenchen Fan Closes #12346 from cloud-fan/tmp. --- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 47251681e3..d074535bf6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -624,7 +624,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { test("SPARK-14554: Dataset.map may generate wrong java code for wide table") { val wideDF = sqlContext.range(10).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) // Make sure the generated code for this plan can compile and execute. 
- wideDF.map(_.getLong(0)).collect() + checkDataset(wideDF.map(_.getLong(0)), 0L until 10 : _*) } } -- cgit v1.2.3 From 23f93f559cbe5436df3bad75f4ffa1219b0d6968 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Tue, 12 Apr 2016 23:06:55 -0700 Subject: [SPARK-13992][CORE][PYSPARK][FOLLOWUP] Update OFF_HEAP semantics for Java api and Python api ## What changes were proposed in this pull request? - updated `OFF_HEAP` semantics for `StorageLevels.java` - updated `OFF_HEAP` semantics for `storagelevel.py` ## How was this patch tested? no need to test Author: Liwei Lin Closes #12126 from lw-lin/storagelevel.py. --- core/src/main/java/org/apache/spark/api/java/StorageLevels.java | 2 +- python/pyspark/storagelevel.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java index 23673d3e3d..3fcb52f615 100644 --- a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java +++ b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java @@ -34,7 +34,7 @@ public class StorageLevels { public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, false, true, 2); public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, false, 1); public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, false, 2); - public static final StorageLevel OFF_HEAP = create(false, false, true, false, 1); + public static final StorageLevel OFF_HEAP = create(true, true, true, false, 1); /** * Create a new StorageLevel object. diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 176e3bb41c..ef012d27cb 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -55,7 +55,7 @@ StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, False) StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, False, 2) StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, False) StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, False, 2) -StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) +StorageLevel.OFF_HEAP = StorageLevel(True, True, True, False, 1) """ .. note:: The following four storage level constants are deprecated in 2.0, since the records \ -- cgit v1.2.3 From dd11e401e45563b4bdc9829f5d23b68dacac8caf Mon Sep 17 00:00:00 2001 From: Charles Allen Date: Wed, 13 Apr 2016 16:02:49 +0100 Subject: [SPARK-14537][CORE] Make TaskSchedulerImpl waiting fail if context is shut down This patch makes the postStartHook throw an IllegalStateException if the SparkContext is shutdown while it is waiting for the backend to be ready Author: Charles Allen Closes #12301 from drcrallen/SPARK-14537. --- .../main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index daed2ff50e..c3159188d9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -571,6 +571,11 @@ private[spark] class TaskSchedulerImpl( return } while (!backend.isReady) { + // Might take a while for backend to be ready if it is waiting on resources. 
+ if (sc.stopped.get) { + // For example: the master removes the application for some reason + throw new IllegalStateException("Spark context stopped while waiting for backend") + } synchronized { this.wait(100) } -- cgit v1.2.3 From 323e7390a5c123c48cc7d6d9be44bee3a7eecd99 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 13 Apr 2016 09:17:46 -0700 Subject: Revert "[SPARK-14154][MLLIB] Simplify the implementation for Kolmogorov–Smirnov test" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit d2a819a6363190b946986ebf6f8001d520098c3b. --- .../mllib/stat/test/KolmogorovSmirnovTest.scala | 77 ++++++++++++++++++++-- 1 file changed, 73 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala index ef284531c9..9748fbf2c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -64,10 +64,11 @@ private[stat] object KolmogorovSmirnovTest extends Logging { */ def testOneSample(data: RDD[Double], cdf: Double => Double): KolmogorovSmirnovTestResult = { val n = data.count().toDouble - val ksStat = data.sortBy(x => x).zipWithIndex().map { case (v, i) => - val f = cdf(v) - math.max(f - i / n, (i + 1) / n - f) - }.max() + val localData = data.sortBy(x => x).mapPartitions { part => + val partDiffs = oneSampleDifferences(part, n, cdf) // local distances + searchOneSampleCandidates(partDiffs) // candidates: local extrema + }.collect() + val ksStat = searchOneSampleStatistic(localData, n) // result: global extreme evalOneSampleP(ksStat, n.toLong) } @@ -83,6 +84,74 @@ private[stat] object KolmogorovSmirnovTest extends Logging { testOneSample(data, cdf) } + /** + * Calculate unadjusted distances between the empirical CDF and the theoretical CDF in a + * partition + * @param partData `Iterator[Double]` 1 partition of a sorted RDD + * @param n `Double` the total size of the RDD + * @param cdf `Double => Double` a function the calculates the theoretical CDF of a value + * @return `Iterator[(Double, Double)] `Unadjusted (ie. off by a constant) potential extrema + * in a partition. The first element corresponds to the (empirical CDF - 1/N) - CDF, + * the second element corresponds to empirical CDF - CDF. 
We can then search the resulting + * iterator for the minimum of the first and the maximum of the second element, and provide + * this as a partition's candidate extrema + */ + private def oneSampleDifferences(partData: Iterator[Double], n: Double, cdf: Double => Double) + : Iterator[(Double, Double)] = { + // zip data with index (within that partition) + // calculate local (unadjusted) empirical CDF and subtract CDF + partData.zipWithIndex.map { case (v, ix) => + // dp and dl are later adjusted by constant, when global info is available + val dp = (ix + 1) / n + val dl = ix / n + val cdfVal = cdf(v) + (dl - cdfVal, dp - cdfVal) + } + } + + /** + * Search the unadjusted differences in a partition and return the + * two extrema (furthest below and furthest above CDF), along with a count of elements in that + * partition + * @param partDiffs `Iterator[(Double, Double)]` the unadjusted differences between empirical CDF + * and CDFin a partition, which come as a tuple of + * (empirical CDF - 1/N - CDF, empirical CDF - CDF) + * @return `Iterator[(Double, Double, Double)]` the local extrema and a count of elements + */ + private def searchOneSampleCandidates(partDiffs: Iterator[(Double, Double)]) + : Iterator[(Double, Double, Double)] = { + val initAcc = (Double.MaxValue, Double.MinValue, 0.0) + val pResults = partDiffs.foldLeft(initAcc) { case ((pMin, pMax, pCt), (dl, dp)) => + (math.min(pMin, dl), math.max(pMax, dp), pCt + 1) + } + val results = if (pResults == initAcc) Array[(Double, Double, Double)]() else Array(pResults) + results.iterator + } + + /** + * Find the global maximum distance between empirical CDF and CDF (i.e. the KS statistic) after + * adjusting local extrema estimates from individual partitions with the amount of elements in + * preceding partitions + * @param localData `Array[(Double, Double, Double)]` A local array containing the collected + * results of `searchOneSampleCandidates` across all partitions + * @param n `Double`The size of the RDD + * @return The one-sample Kolmogorov Smirnov Statistic + */ + private def searchOneSampleStatistic(localData: Array[(Double, Double, Double)], n: Double) + : Double = { + val initAcc = (Double.MinValue, 0.0) + // adjust differences based on the number of elements preceding it, which should provide + // the correct distance between empirical CDF and CDF + val results = localData.foldLeft(initAcc) { case ((prevMax, prevCt), (minCand, maxCand, ct)) => + val adjConst = prevCt / n + val dist1 = math.abs(minCand + adjConst) + val dist2 = math.abs(maxCand + adjConst) + val maxVal = Array(prevMax, dist1, dist2).max + (maxVal, prevCt + ct) + } + results._1 + } + /** * A convenience function that allows running the KS test for 1 set of sample data against * a named distribution -- cgit v1.2.3 From 1018a1c1eb33eefbfb9025fac7a1cdafc5cbf8f8 Mon Sep 17 00:00:00 2001 From: Timothy Hunter Date: Wed, 13 Apr 2016 11:06:42 -0700 Subject: [SPARK-14568][ML] Instrumentation framework for logistic regression ## What changes were proposed in this pull request? This adds extra logging information about a `LogisticRegression` estimator when being fit on a dataset. 
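In outline, `fit()` now opens an instrumentation session, logs the hyper-parameters and some dataset statistics, and marks the session as finished once training succeeds. The sketch below illustrates that call pattern only -- `ToyInstrumentation` and `ToyLogisticRegression` are made-up names; the real helper is the new `org.apache.spark.ml.util.Instrumentation` class added in this patch.

```
// Toy version of the session logger, just to show the call pattern inside fit().
final class ToyInstrumentation(estimatorName: String, sessionId: Long) {
  private def log(msg: String): Unit = println(s"Instrumentation($estimatorName-$sessionId): $msg")
  def logParams(params: (String, Any)*): Unit =
    log(params.map { case (k, v) => s""""$k":$v""" }.mkString("{", ",", "}"))
  def logNumFeatures(n: Long): Unit = log(s"""{"numFeatures":$n}""")
  def logNumClasses(n: Long): Unit = log(s"""{"numClasses":$n}""")
  def logSuccess(): Unit = log("training finished")
}

object ToyLogisticRegression {
  def fit(features: Array[Array[Double]], labels: Array[Int]): Array[Double] = {
    val instr = new ToyInstrumentation("LogisticRegression", sessionId = 1)
    instr.logParams("regParam" -> 0.3, "elasticNetParam" -> 0.8, "maxIter" -> 10)
    val numFeatures = if (features.isEmpty) 0 else features.head.length
    instr.logNumFeatures(numFeatures.toLong)
    instr.logNumClasses(labels.distinct.length.toLong)
    val coefficients = Array.fill(numFeatures)(0.0)   // actual training omitted
    instr.logSuccess()
    coefficients
  }
}
```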
With this PR, you see the following extra lines when running the example in the documentation: ``` 16/04/13 07:19:00 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): training: numPartitions=1 storageLevel=StorageLevel(disk=true, memory=true, offheap=false, deserialized=true, replication=1) 16/04/13 07:19:00 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): {"regParam":0.3,"elasticNetParam":0.8,"maxIter":10} ... 16/04/12 11:48:07 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_a89eb23cb386-358781145):numClasses=2 16/04/12 11:48:07 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_a89eb23cb386-358781145):numFeatures=692 ... 16/04/13 07:19:01 INFO Instrumentation: Instrumentation(LogisticRegression-logreg_55dd3c09f164-1230977381-1): training finished ``` ## How was this patch tested? This PR was manually tested. Author: Timothy Hunter Closes #12331 from thunterdb/1604-instrumentation. --- .../ml/classification/LogisticRegression.scala | 11 +- .../org/apache/spark/ml/util/Instrumentation.scala | 117 +++++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 4a3fe5c663..c2b440059b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -273,6 +273,10 @@ class LogisticRegression @Since("1.2.0") ( if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) + val instr = Instrumentation.create(this, instances) + instr.logParams(regParam, elasticNetParam, standardization, threshold, + maxIter, tol, fitIntercept) + val (summarizer, labelSummarizer) = { val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer), instance: Instance) => @@ -291,6 +295,9 @@ class LogisticRegression @Since("1.2.0") ( val numClasses = histogram.length val numFeatures = summarizer.mean.size + instr.logNumClasses(numClasses) + instr.logNumFeatures(numFeatures) + val (coefficients, intercept, objectiveHistory) = { if (numInvalid != 0) { val msg = s"Classification labels should be in {0 to ${numClasses - 1} " + @@ -444,7 +451,9 @@ class LogisticRegression @Since("1.2.0") ( $(labelCol), $(featuresCol), objectiveHistory) - model.setSummary(logRegSummary) + val m = model.setSummary(logRegSummary) + instr.logSuccess(m) + m } @Since("1.4.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala new file mode 100644 index 0000000000..7e57cefc44 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.util.concurrent.atomic.AtomicLong + +import org.json4s._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.internal.Logging +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param.Param +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dataset + +/** + * A small wrapper that defines a training session for an estimator, and some methods to log + * useful information during this session. + * + * A new instance is expected to be created within fit(). + * + * @param estimator the estimator that is being fit + * @param dataset the training dataset + * @tparam E the type of the estimator + */ +private[ml] class Instrumentation[E <: Estimator[_]] private ( + estimator: E, dataset: RDD[_]) extends Logging { + + private val id = Instrumentation.counter.incrementAndGet() + private val prefix = { + val className = estimator.getClass.getSimpleName + s"$className-${estimator.uid}-${dataset.hashCode()}-$id: " + } + + init() + + private def init(): Unit = { + log(s"training: numPartitions=${dataset.partitions.length}" + + s" storageLevel=${dataset.getStorageLevel}") + } + + /** + * Logs a message with a prefix that uniquely identifies the training session. + */ + def log(msg: String): Unit = { + logInfo(prefix + msg) + } + + /** + * Logs the value of the given parameters for the estimator being used in this session. + */ + def logParams(params: Param[_]*): Unit = { + val pairs: Seq[(String, JValue)] = for { + p <- params + value <- estimator.get(p) + } yield { + val cast = p.asInstanceOf[Param[Any]] + p.name -> parse(cast.jsonEncode(value)) + } + log(compact(render(map2jvalue(pairs.toMap)))) + } + + def logNumFeatures(num: Long): Unit = { + log(compact(render("numFeatures" -> num))) + } + + def logNumClasses(num: Long): Unit = { + log(compact(render("numClasses" -> num))) + } + + /** + * Logs the successful completion of the training session and the value of the learned model. + */ + def logSuccess(model: Model[_]): Unit = { + log(s"training finished") + } +} + +/** + * Some common methods for logging information about a training session. + */ +private[ml] object Instrumentation { + private val counter = new AtomicLong(0) + + /** + * Creates an instrumentation object for a training session. + */ + def create[E <: Estimator[_]]( + estimator: E, dataset: Dataset[_]): Instrumentation[E] = { + create[E](estimator, dataset.rdd) + } + + /** + * Creates an instrumentation object for a training session. + */ + def create[E <: Estimator[_]]( + estimator: E, dataset: RDD[_]): Instrumentation[E] = { + new Instrumentation[E](estimator, dataset) + } + +} -- cgit v1.2.3 From 7d2ed8cc030f3d84fea47fded072c320c3d87ca7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 Apr 2016 11:08:34 -0700 Subject: [SPARK-14388][SQL] Implement CREATE TABLE ## What changes were proposed in this pull request? This patch implements the `CREATE TABLE` command using the `SessionCatalog`. Previously we handled only `CTAS` and `CREATE TABLE ... USING`. 
This requires us to refactor `CatalogTable` to accept various fields (e.g. bucket and skew columns) and pass them to Hive. WIP: Note that I haven't verified whether this actually works yet! But I believe it does. ## How was this patch tested? Tests will come in a future commit. Author: Andrew Or Author: Yin Huai Closes #12271 from andrewor14/create-table-ddl. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../spark/sql/catalyst/catalog/interface.scala | 26 ++- .../sql/catalyst/catalog/CatalogTestCases.scala | 8 +- .../spark/sql/execution/SparkSqlParser.scala | 25 +-- .../apache/spark/sql/execution/command/ddl.scala | 23 --- .../spark/sql/execution/command/tables.scala | 80 ++++++++ .../sql/execution/command/DDLCommandSuite.scala | 20 +- .../spark/sql/execution/command/DDLSuite.scala | 2 - .../spark/sql/hive/thriftserver/CliSuite.scala | 8 +- .../hive/execution/HiveCompatibilitySuite.scala | 128 ++++++------ .../spark/sql/hive/HiveMetastoreCatalog.scala | 16 +- .../spark/sql/hive/client/HiveClientImpl.scala | 44 ++-- .../spark/sql/hive/execution/HiveSqlParser.scala | 223 +++++++++++++-------- .../spark/sql/hive/HiveDDLCommandSuite.scala | 217 ++++++++++++++++++-- .../spark/sql/hive/HiveMetastoreCatalogSuite.scala | 4 +- .../spark/sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../spark/sql/hive/execution/PruningSuite.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 18 files changed, 571 insertions(+), 262 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 0e2cd39448..a937ad1eb7 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -272,8 +272,7 @@ createFileFormat ; fileFormat - : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? - (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? 
#tableFileFormat | identifier #genericFileFormat ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 4ef59316ce..ad989a97e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -220,14 +220,30 @@ case class CatalogTable( tableType: CatalogTableType, storage: CatalogStorageFormat, schema: Seq[CatalogColumn], - partitionColumns: Seq[CatalogColumn] = Seq.empty, - sortColumns: Seq[CatalogColumn] = Seq.empty, - numBuckets: Int = 0, + partitionColumnNames: Seq[String] = Seq.empty, + sortColumnNames: Seq[String] = Seq.empty, + bucketColumnNames: Seq[String] = Seq.empty, + numBuckets: Int = -1, createTime: Long = System.currentTimeMillis, - lastAccessTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, properties: Map[String, String] = Map.empty, viewOriginalText: Option[String] = None, - viewText: Option[String] = None) { + viewText: Option[String] = None, + comment: Option[String] = None) { + + // Verify that the provided columns are part of the schema + private val colNames = schema.map(_.name).toSet + private def requireSubsetOfSchema(cols: Seq[String], colType: String): Unit = { + require(cols.toSet.subsetOf(colNames), s"$colType columns (${cols.mkString(", ")}) " + + s"must be a subset of schema (${colNames.mkString(", ")}) in table '$identifier'") + } + requireSubsetOfSchema(partitionColumnNames, "partition") + requireSubsetOfSchema(sortColumnNames, "sort") + requireSubsetOfSchema(bucketColumnNames, "bucket") + + /** Columns this table is partitioned by. */ + def partitionColumns: Seq[CatalogColumn] = + schema.filter { c => partitionColumnNames.contains(c.name) } /** Return the database this table was specified to belong to, assuming it exists. */ def database: String = identifier.database.getOrElse { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 0d9b0851fa..f961fe3292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -553,8 +553,12 @@ abstract class CatalogTestUtils { identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL_TABLE, storage = storageFormat, - schema = Seq(CatalogColumn("col1", "int"), CatalogColumn("col2", "string")), - partitionColumns = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string"))) + schema = Seq( + CatalogColumn("col1", "int"), + CatalogColumn("col2", "string"), + CatalogColumn("a", "int"), + CatalogColumn("b", "string")), + partitionColumnNames = Seq("a", "b")) } def newFunc(name: String, database: Option[String] = None): CatalogFunction = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 73d9640c35..af92cecee5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -179,7 +179,9 @@ class SparkSqlAstBuilder extends AstBuilder { } } - /** Type to keep track of a table header. 
*/ + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) /** @@ -616,10 +618,7 @@ class SparkSqlAstBuilder extends AstBuilder { case s: GenericFileFormatContext => (Seq.empty[String], Option(s.identifier.getText)) case s: TableFileFormatContext => - val elements = Seq(s.inFmt, s.outFmt) ++ - Option(s.serdeCls).toSeq ++ - Option(s.inDriver).toSeq ++ - Option(s.outDriver).toSeq + val elements = Seq(s.inFmt, s.outFmt) ++ Option(s.serdeCls).toSeq (elements.map(string), None) } AlterTableSetFileFormat( @@ -773,22 +772,6 @@ class SparkSqlAstBuilder extends AstBuilder { .map(_.identifier.getText)) } - /** - * Create a skew specification. This contains three components: - * - The Skewed Columns - * - Values for which are skewed. The size of each entry must match the number of skewed columns. - * - A store in directory flag. - */ - override def visitSkewSpec( - ctx: SkewSpecContext): (Seq[String], Seq[Seq[String]], Boolean) = withOrigin(ctx) { - val skewedValues = if (ctx.constantList != null) { - Seq(visitConstantList(ctx.constantList)) - } else { - visitNestedConstantList(ctx.nestedConstantList) - } - (visitIdentifierList(ctx.identifierList), skewedValues, ctx.DIRECTORIES != null) - } - /** * Convert a nested constants list into a sequence of string sequences. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 5137bd11d8..234099ad15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -224,29 +224,6 @@ case class DropTable( } } -/** - * A command that renames a table/view. - * - * The syntax of this command is: - * {{{ - * ALTER TABLE table1 RENAME TO table2; - * ALTER VIEW view1 RENAME TO view2; - * }}} - */ -case class AlterTableRename( - oldName: TableIdentifier, - newName: TableIdentifier) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog - catalog.invalidateTable(oldName) - catalog.renameTable(oldName, newName) - Seq.empty[Row] - } - -} - /** * A command that sets table/view properties. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala new file mode 100644 index 0000000000..9c6030502d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTable + + +// TODO: move the rest of the table commands from ddl.scala to this file + +/** + * A command to create a table. + * + * Note: This is currently used only for creating Hive tables. + * This is not intended for temporary tables. + * + * The syntax of using this command in SQL is: + * {{{ + * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1 data_type [COMMENT col_comment], ...)] + * [COMMENT table_comment] + * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] + * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] + * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) + * [STORED AS DIRECTORIES] + * [ROW FORMAT row_format] + * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] + * [AS select_statement]; + * }}} + */ +case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.createTable(table, ifNotExists) + Seq.empty[Row] + } + +} + + +/** + * A command that renames a table/view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; + * }}} + */ +case class AlterTableRename( + oldName: TableIdentifier, + newName: TableIdentifier) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + catalog.invalidateTable(oldName) + catalog.renameTable(oldName, newName) + Seq.empty[Row] + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 1c8dd68286..6e6475ee29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -440,37 +440,25 @@ class DDLCommandSuite extends PlanTest { } test("alter table: set file format") { - val sql1 = - """ - |ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' - |OUTPUTFORMAT 'test' SERDE 'test' INPUTDRIVER 'test' OUTPUTDRIVER 'test' - """.stripMargin - val sql2 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + + val sql1 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + "OUTPUTFORMAT 'test' SERDE 'test'" - val sql3 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + + val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + "SET FILEFORMAT PARQUET" val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) val tableIdent = TableIdentifier("table_name", None) val expected1 = AlterTableSetFileFormat( tableIdent, None, - List("test", "test", "test", "test", "test"), + List("test", "test", "test"), None)(sql1) val expected2 = AlterTableSetFileFormat( - tableIdent, - None, - List("test", "test", "test"), - None)(sql2) - val expected3 = AlterTableSetFileFormat( tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")), Seq(), - Some("PARQUET"))(sql3) + Some("PARQUET"))(sql2) 
comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) } test("alter table: set location") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 40a8b0e614..9ffffa0bdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -380,8 +380,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext - test("show tables") { withTempTable("show1a", "show2b") { sql( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index bfc3d195ff..eb49eabcb1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -162,7 +162,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { runCliWithin(3.minute)( "CREATE TABLE hive_test(key INT, val STRING);" - -> "OK", + -> "", "SHOW TABLES;" -> "hive_test", s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;" @@ -187,7 +187,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "USE hive_test_db;" -> "", "CREATE TABLE hive_test(key INT, val STRING);" - -> "OK", + -> "", "SHOW TABLES;" -> "hive_test" ) @@ -210,9 +210,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { """CREATE TABLE t1(key string, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; """.stripMargin - -> "OK", + -> "", "CREATE TABLE sourceTable (key INT, val STRING);" - -> "OK", + -> "", s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" -> "OK", "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f0eeda09db..a45d180464 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -366,10 +366,76 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "sort_merge_join_desc_6", "sort_merge_join_desc_7", + // These tests try to create a table with bucketed columns, which we don't support + "auto_join32", + "auto_join_filters", + "auto_smb_mapjoin_14", + "ct_case_insensitive", + "explain_rearrange", + "groupby_sort_10", + "groupby_sort_2", + "groupby_sort_3", + "groupby_sort_4", + "groupby_sort_5", + "groupby_sort_7", + "groupby_sort_8", + "groupby_sort_9", + "groupby_sort_test_1", + "inputddl4", + "join_filters", + "join_nulls", + "join_nullsafe", + "load_dyn_part2", + "orc_empty_files", + "reduce_deduplicate", + "smb_mapjoin9", + "smb_mapjoin_1", + "smb_mapjoin_10", + "smb_mapjoin_13", + "smb_mapjoin_14", + "smb_mapjoin_15", + "smb_mapjoin_16", + "smb_mapjoin_17", + "smb_mapjoin_2", + "smb_mapjoin_21", + "smb_mapjoin_25", + "smb_mapjoin_3", + 
"smb_mapjoin_4", + "smb_mapjoin_5", + "smb_mapjoin_6", + "smb_mapjoin_7", + "smb_mapjoin_8", + "sort_merge_join_desc_1", + "sort_merge_join_desc_2", + "sort_merge_join_desc_3", + "sort_merge_join_desc_4", + + // These tests try to create a table with skewed columns, which we don't support + "create_skewed_table1", + "skewjoinopt13", + "skewjoinopt18", + "skewjoinopt9", + // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", "alter_index", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", // Macro commands are not supported "macro", @@ -435,33 +501,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "auto_join3", "auto_join30", "auto_join31", - "auto_join32", "auto_join4", "auto_join5", "auto_join6", "auto_join7", "auto_join8", "auto_join9", - "auto_join_filters", "auto_join_nulls", "auto_join_reordering_values", - "auto_smb_mapjoin_14", - "auto_sortmerge_join_1", - "auto_sortmerge_join_10", - "auto_sortmerge_join_11", - "auto_sortmerge_join_12", - "auto_sortmerge_join_13", - "auto_sortmerge_join_14", - "auto_sortmerge_join_15", - "auto_sortmerge_join_16", - "auto_sortmerge_join_2", - "auto_sortmerge_join_3", - "auto_sortmerge_join_4", - "auto_sortmerge_join_5", - "auto_sortmerge_join_6", - "auto_sortmerge_join_7", - "auto_sortmerge_join_8", - "auto_sortmerge_join_9", "binary_constant", "binarysortable_1", "cast1", @@ -492,13 +539,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "create_insert_outputformat", "create_like_tbl_props", "create_nested_type", - "create_skewed_table1", "create_struct_table", "create_view_translate", "cross_join", "cross_product_check_1", "cross_product_check_2", - "ct_case_insensitive", "database_drop", "database_location", "database_properties", @@ -534,7 +579,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "escape_distributeby1", "escape_orderby1", "escape_sortby1", - "explain_rearrange", "fileformat_mix", "fileformat_sequencefile", "fileformat_text", @@ -589,16 +633,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_neg_float", "groupby_ppd", "groupby_ppr", - "groupby_sort_10", - "groupby_sort_2", - "groupby_sort_3", - "groupby_sort_4", - "groupby_sort_5", "groupby_sort_6", - "groupby_sort_7", - "groupby_sort_8", - "groupby_sort_9", - "groupby_sort_test_1", "having", "implicit_cast1", "index_serde", @@ -653,7 +688,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl1", "inputddl2", "inputddl3", - "inputddl4", "inputddl6", "inputddl7", "inputddl8", @@ -709,11 +743,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_array", "join_casesensitive", "join_empty", - "join_filters", "join_hive_626", "join_map_ppr", - "join_nulls", - "join_nullsafe", "join_rc", "join_reorder2", "join_reorder3", @@ -737,7 +768,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_dyn_part13", "load_dyn_part14", "load_dyn_part14_win", - "load_dyn_part2", "load_dyn_part3", "load_dyn_part4", "load_dyn_part5", @@ -790,7 +820,6 @@ 
class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "nullscript", "optional_outer", "orc_dictionary_threshold", - "orc_empty_files", "order", "order2", "outer_join_ppr", @@ -846,7 +875,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "rcfile_null_value", "rcfile_toleratecorruptions", "rcfile_union", - "reduce_deduplicate", "reduce_deduplicate_exclude_gby", "reduce_deduplicate_exclude_join", "reduce_deduplicate_extended", @@ -867,31 +895,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "show_functions", "show_partitions", "show_tblproperties", - "skewjoinopt13", - "skewjoinopt18", - "skewjoinopt9", - "smb_mapjoin9", - "smb_mapjoin_1", - "smb_mapjoin_10", - "smb_mapjoin_13", - "smb_mapjoin_14", - "smb_mapjoin_15", - "smb_mapjoin_16", - "smb_mapjoin_17", - "smb_mapjoin_2", - "smb_mapjoin_21", - "smb_mapjoin_25", - "smb_mapjoin_3", - "smb_mapjoin_4", - "smb_mapjoin_5", - "smb_mapjoin_6", - "smb_mapjoin_7", - "smb_mapjoin_8", "sort", - "sort_merge_join_desc_1", - "sort_merge_join_desc_2", - "sort_merge_join_desc_3", - "sort_merge_join_desc_4", "stats0", "stats_aggregator_error_1", "stats_empty_partition", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 14f331961e..ccc8345d73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -91,7 +91,7 @@ private[hive] object HiveSerDe { "textfile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), "avro" -> HiveSerDe( @@ -905,8 +905,13 @@ private[hive] case class MetastoreRelation( val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(toHiveColumn).asJava) - tTable.setPartitionKeys(table.partitionColumns.map(toHiveColumn).asJava) + + // Note: In Hive the schema and partition columns must be disjoint sets + val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => + table.partitionColumnNames.contains(c.getName) + } + sd.setCols(schema.asJava) + tTable.setPartitionKeys(partCols.asJava) table.storage.locationUri.foreach(sd.setLocation) table.storage.inputFormat.foreach(sd.setInputFormat) @@ -1013,7 +1018,10 @@ private[hive] case class MetastoreRelation( val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = table.schema.map(_.toAttribute) + // TODO: just make this hold the schema itself, not just non-partition columns + val attributes = table.schema + .filter { c => !table.partitionColumnNames.contains(c.name) } + .map(_.toAttribute) val output = attributes ++ partitionKeys diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 39e26acd7f..2a1fff92b5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -299,6 +299,10 @@ private[hive] class HiveClientImpl( tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up 
$dbName.$tableName") Option(client.getTable(dbName, tableName, false)).map { h => + // Note: Hive separates partition columns and the schema, but for us the + // partition columns are part of the schema + val partCols = h.getPartCols.asScala.map(fromHiveColumn) + val schema = h.getCols.asScala.map(fromHiveColumn) ++ partCols CatalogTable( identifier = TableIdentifier(h.getTableName, Option(h.getDbName)), tableType = h.getTableType match { @@ -307,9 +311,10 @@ private[hive] class HiveClientImpl( case HiveTableType.INDEX_TABLE => CatalogTableType.INDEX_TABLE case HiveTableType.VIRTUAL_VIEW => CatalogTableType.VIRTUAL_VIEW }, - schema = h.getCols.asScala.map(fromHiveColumn), - partitionColumns = h.getPartCols.asScala.map(fromHiveColumn), - sortColumns = Seq(), + schema = schema, + partitionColumnNames = partCols.map(_.name), + sortColumnNames = Seq(), // TODO: populate this + bucketColumnNames = h.getBucketCols.asScala, numBuckets = h.getNumBuckets, createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, @@ -675,24 +680,37 @@ private[hive] class HiveClientImpl( private def toHiveTable(table: CatalogTable): HiveTable = { val hiveTable = new HiveTable(table.database, table.identifier.table) - // For EXTERNAL_TABLE/MANAGED_TABLE, we also need to set EXTERNAL field in - // the table properties accodringly. Otherwise, if EXTERNAL_TABLE is the table type - // but EXTERNAL field is not set, Hive metastore will change the type to - // MANAGED_TABLE (see - // metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105) + // For EXTERNAL_TABLE, we also need to set EXTERNAL field in the table properties. + // Otherwise, Hive metastore will change the table to a MANAGED_TABLE. + // (metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105) hiveTable.setTableType(table.tableType match { case CatalogTableType.EXTERNAL_TABLE => hiveTable.setProperty("EXTERNAL", "TRUE") HiveTableType.EXTERNAL_TABLE case CatalogTableType.MANAGED_TABLE => - hiveTable.setProperty("EXTERNAL", "FALSE") HiveTableType.MANAGED_TABLE case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW }) - hiveTable.setFields(table.schema.map(toHiveColumn).asJava) - hiveTable.setPartCols(table.partitionColumns.map(toHiveColumn).asJava) + // Note: In Hive the schema and partition columns must be disjoint sets + val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => + table.partitionColumnNames.contains(c.getName) + } + if (table.schema.isEmpty) { + // This is a hack to preserve existing behavior. Before Spark 2.0, we do not + // set a default serde here (this was done in Hive), and so if the user provides + // an empty schema Hive would automatically populate the schema with a single + // field "col". However, after SPARK-14388, we set the default serde to + // LazySimpleSerde so this implicit behavior no longer happens. Therefore, + // we need to do it in Spark ourselves. 
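// A sketch, not part of the patch: the schema/partition-column split used a few lines
// above, reduced to a toy example over plain (name, dataType) pairs. Because the
// CatalogTable schema now carries the partition columns as well, they must be separated
// again before handing the table to Hive, whose metastore keeps the two sets disjoint.
val toySchema = Seq(("key", "int"), ("value", "string"), ("dt", "string"))
val toyPartitionColumnNames = Set("dt")
val (toyPartCols, toyDataCols) =
  toySchema.partition { case (name, _) => toyPartitionColumnNames.contains(name) }
assert(toyPartCols == Seq(("dt", "string")))
assert(toyDataCols == Seq(("key", "int"), ("value", "string")))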
+ hiveTable.setFields( + Seq(new FieldSchema("col", "array", "from deserializer")).asJava) + } else { + hiveTable.setFields(schema.asJava) + } + hiveTable.setPartCols(partCols.asJava) // TODO: set sort columns here too + hiveTable.setBucketCols(table.bucketColumnNames.asJava) hiveTable.setOwner(conf.getUser) hiveTable.setNumBuckets(table.numBuckets) hiveTable.setCreateTime((table.createTime / 1000).toInt) @@ -700,9 +718,11 @@ private[hive] class HiveClientImpl( table.storage.locationUri.foreach { loc => shim.setDataLocation(hiveTable, loc) } table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) - table.storage.serde.foreach(hiveTable.setSerializationLib) + hiveTable.setSerializationLib( + table.storage.serde.getOrElse("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) table.storage.serdeProperties.foreach { case (k, v) => hiveTable.setSerdeParam(k, v) } table.properties.foreach { case (k, v) => hiveTable.setProperty(k, v) } + table.comment.foreach { c => hiveTable.setProperty("comment", c) } table.viewOriginalText.foreach { t => hiveTable.setViewOriginalText(t) } table.viewText.foreach { t => hiveTable.setViewExpandedText(t) } hiveTable diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 7a435117e7..b14db7fe71 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ @@ -33,8 +34,9 @@ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlAstBuilder +import org.apache.spark.sql.execution.command.CreateTable import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView} -import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveSerDe} +import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveMetastoreTypes, HiveSerDe} import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper /** @@ -121,84 +123,116 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } /** - * Create a [[CatalogStorageFormat]]. This is part of the [[CreateTableAsSelect]] command. + * Create a [[CatalogStorageFormat]] for creating tables. 
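 * Editor's note, not part of the patch: at the SQL level the dispatch below accepts the
 * two STORED AS forms and rejects STORED BY; the class names are illustrative only.
 * {{{
 *   STORED AS INPUTFORMAT 'com.example.In' OUTPUTFORMAT 'com.example.Out'  -- table file format
 *   STORED AS PARQUET                                                      -- generic file format
 *   STORED BY 'com.example.Handler'   -- fails with "Operation not allowed: ... STORED BY ..."
 * }}}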
*/ override def visitCreateFileFormat( ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - if (ctx.storageHandler == null) { - typedVisit[CatalogStorageFormat](ctx.fileFormat) - } else { - visitStorageHandler(ctx.storageHandler) + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + visitTableFileFormat(c) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + visitGenericFileFormat(c) + case (null, storageHandler) => + throw new ParseException("Operation not allowed: ... STORED BY storage_handler ...", ctx) + case _ => + throw new ParseException("expected either STORED AS or STORED BY, not both", ctx) } } /** - * Create a [[CreateTableAsSelect]] command. + * Create a table, returning either a [[CreateTable]] or a [[CreateTableAsSelect]]. + * + * This is not used to create datasource tables, which is handled through + * "CREATE TABLE ... USING ...". + * + * Note: several features are currently not supported - temporary tables, bucketing, + * skewed columns and storage handlers (STORED BY). + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1 data_type [COMMENT col_comment], ...)] + * [COMMENT table_comment] + * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] + * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] + * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) [STORED AS DIRECTORIES]] + * [ROW FORMAT row_format] + * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] + * [AS select_statement]; + * }}} */ - override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = { - if (ctx.query == null) { - HiveNativeCommand(command(ctx)) + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + // TODO: implement temporary tables + if (temp) { + throw new ParseException( + "CREATE TEMPORARY TABLE is not supported yet. " + + "Please use registerTempTable as an alternative.", ctx) + } + if (ctx.skewSpec != null) { + throw new ParseException("Operation not allowed: CREATE TABLE ... SKEWED BY ...", ctx) + } + if (ctx.bucketSpec != null) { + throw new ParseException("Operation not allowed: CREATE TABLE ... CLUSTERED BY ...", ctx) + } + val tableType = if (external) { + CatalogTableType.EXTERNAL_TABLE } else { - // Get the table header. - val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - val tableType = if (external) { - CatalogTableType.EXTERNAL_TABLE - } else { - CatalogTableType.MANAGED_TABLE - } - - // Unsupported clauses. - if (temp) { - throw new ParseException(s"Unsupported operation: TEMPORARY clause.", ctx) - } - if (ctx.bucketSpec != null) { - // TODO add this - we need cluster columns in the CatalogTable for this to work. - throw new ParseException("Unsupported operation: " + - "CLUSTERED BY ... [ORDERED BY ...] INTO ... BUCKETS clause.", ctx) - } - if (ctx.skewSpec != null) { - throw new ParseException("Operation not allowed: " + - "SKEWED BY ... ON ... [STORED AS DIRECTORIES] clause.", ctx) - } - - // Create the schema. 
- val schema = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns(_, _.toLowerCase)) - - // Get the column by which the table is partitioned. - val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns(_)) - - // Create the storage. - def format(fmt: ParserRuleContext): CatalogStorageFormat = { - Option(fmt).map(typedVisit[CatalogStorageFormat]).getOrElse(EmptyStorageFormat) - } - // Default storage. + CatalogTableType.MANAGED_TABLE + } + val comment = Option(ctx.STRING).map(string) + val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns) + val cols = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns) + val properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val selectQuery = Option(ctx.query).map(plan) + + // Note: Hive requires partition columns to be distinct from the schema, so we need + // to include the partition columns here explicitly + val schema = cols ++ partitionCols + + // Storage format + val defaultStorage: CatalogStorageFormat = { val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } - // Defined storage. - val fileStorage = format(ctx.createFileFormat) - val rowStorage = format(ctx.rowFormat) - val storage = CatalogStorageFormat( - Option(ctx.locationSpec).map(visitLocationSpec), - fileStorage.inputFormat.orElse(hiveSerDe.inputFormat), - fileStorage.outputFormat.orElse(hiveSerDe.outputFormat), - rowStorage.serde.orElse(hiveSerDe.serde).orElse(fileStorage.serde), - rowStorage.serdeProperties ++ fileStorage.serdeProperties - ) + val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf) + CatalogStorageFormat( + locationUri = None, + inputFormat = defaultHiveSerde.flatMap(_.inputFormat) + .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), + outputFormat = defaultHiveSerde.flatMap(_.outputFormat) + .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + // Note: Keep this unspecified because we use the presence of the serde to decide + // whether to convert a table created by CTAS to a datasource table. + serde = None, + serdeProperties = Map()) + } + val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + .getOrElse(EmptyStorageFormat) + val rowStorage = Option(ctx.rowFormat).map(visitRowFormat).getOrElse(EmptyStorageFormat) + val location = Option(ctx.locationSpec).map(visitLocationSpec) + val storage = CatalogStorageFormat( + locationUri = location, + inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), + outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), + serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), + serdeProperties = rowStorage.serdeProperties ++ fileStorage.serdeProperties) + + // TODO support the sql text - have a proper location for this! 
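// A sketch, not part of the patch: the precedence encoded by the orElse chains above,
// shown with hypothetical serde class names. An explicit ROW FORMAT SERDE wins over a
// serde implied by STORED AS, which in turn wins over the default storage, whose serde
// is deliberately left unset (None).
val rowSerde: Option[String] = Some("com.example.RowFormatSerde")
val fileSerde: Option[String] = Some("com.example.FileFormatSerde")
val defaultSerde: Option[String] = None
assert(rowSerde.orElse(fileSerde).orElse(defaultSerde) == Some("com.example.RowFormatSerde"))
assert(Option.empty[String].orElse(fileSerde).orElse(defaultSerde) == Some("com.example.FileFormatSerde"))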
+ val tableDesc = CatalogTable( + identifier = name, + tableType = tableType, + storage = storage, + schema = schema, + partitionColumnNames = partitionCols.map(_.name), + properties = properties, + comment = comment) - val tableDesc = CatalogTable( - identifier = table, - tableType = tableType, - schema = schema, - partitionColumns = partitionCols, - storage = storage, - properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), - // TODO support the sql text - have a proper location for this! - viewText = Option(ctx.STRING).map(string)) - CTAS(tableDesc, plan(ctx.query), ifNotExists) + selectQuery match { + case Some(q) => CTAS(tableDesc, q, ifNotExists) + case None => CreateTable(tableDesc, ifNotExists) } } @@ -353,25 +387,19 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { private val EmptyStorageFormat = CatalogStorageFormat(None, None, None, None, Map.empty) /** - * Create a [[CatalogStorageFormat]]. The INPUTDRIVER and OUTPUTDRIVER clauses are currently - * ignored. + * Create a [[CatalogStorageFormat]]. */ override def visitTableFileFormat( ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - import ctx._ - if (inDriver != null || outDriver != null) { - throw new ParseException( - s"Operation not allowed: INPUTDRIVER ... OUTPUTDRIVER ... clauses", ctx) - } EmptyStorageFormat.copy( - inputFormat = Option(string(inFmt)), - outputFormat = Option(string(outFmt)), - serde = Option(serdeCls).map(string) + inputFormat = Option(string(ctx.inFmt)), + outputFormat = Option(string(ctx.outFmt)), + serde = Option(ctx.serdeCls).map(string) ) } /** - * Resolve a [[HiveSerDe]] based on the format name given. + * Resolve a [[HiveSerDe]] based on the name given and return it as a [[CatalogStorageFormat]]. */ override def visitGenericFileFormat( ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { @@ -388,11 +416,28 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } /** - * Storage Handlers are currently not supported in the statements we support (CTAS). + * Create a [[RowFormat]] used for creating tables. + * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} */ - override def visitStorageHandler( - ctx: StorageHandlerContext): CatalogStorageFormat = withOrigin(ctx) { - throw new ParseException("Storage Handlers are currently unsupported.", ctx) + private def visitRowFormat(ctx: RowFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } } /** @@ -435,13 +480,15 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { /** * Create a sequence of [[CatalogColumn]]s from a column list */ - private def visitCatalogColumns( - ctx: ColTypeListContext, - formatter: String => String = identity): Seq[CatalogColumn] = withOrigin(ctx) { + private def visitCatalogColumns(ctx: ColTypeListContext): Seq[CatalogColumn] = withOrigin(ctx) { ctx.colType.asScala.map { col => CatalogColumn( - formatter(col.identifier.getText), - col.dataType.getText.toLowerCase, // TODO validate this? 
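// A sketch, not part of the patch: why the column type below is round-tripped through
// CatalystSqlParser instead of being lower-cased wholesale. For struct types the field
// names keep their original case while the type keywords are normalized; the expected
// string is approximate.
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
val parsedType = CatalystSqlParser.parseDataType("STRUCT<AppId: STRING, SessionId: BIGINT>")
// parsedType.simpleString is roughly "struct<AppId:string,SessionId:bigint>", whereas a
// blanket toLowerCase on the raw type string would have folded "AppId" and "SessionId" too.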
+ col.identifier.getText.toLowerCase, + // Note: for types like "STRUCT" we can't + // just convert the whole type string to lower case, otherwise the struct field names + // will no longer be case sensitive. Instead, we rely on our parser to get the proper + // case before passing it to Hive. + CatalystSqlParser.parseDataType(col.dataType.getText).simpleString, nullable = true, Option(col.STRING).map(string)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index e8086aec32..68d3ea6ed9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} +import org.apache.spark.sql.execution.command.CreateTable import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser} class HiveDDLCommandSuite extends PlanTest { @@ -36,6 +37,7 @@ class HiveDDLCommandSuite extends PlanTest { private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parser.parsePlan(sql).collect { + case CreateTable(desc, allowExisting) => (desc, allowExisting) case CreateTableAsSelect(desc, _, allowExisting) => (desc, allowExisting) case CreateViewAsSelect(desc, _, allowExisting, _, _) => (desc, allowExisting) }.head @@ -76,9 +78,12 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("page_url", "string") :: CatalogColumn("referrer_url", "string") :: CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) + CatalogColumn("country", "string", comment = Some("country of origination")) :: + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.comment == Some("This is the staging page view table")) // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.viewText.isEmpty) assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: @@ -123,9 +128,12 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("page_url", "string") :: CatalogColumn("referrer_url", "string") :: CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) + CatalogColumn("country", "string", comment = Some("country of origination")) :: + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: @@ -151,7 +159,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) 
assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) assert(desc.storage.serde.isEmpty) assert(desc.properties == Map()) } @@ -203,17 +211,6 @@ class HiveDDLCommandSuite extends PlanTest { |AS SELECT key, value FROM src ORDER BY key, value """.stripMargin) } - intercept[ParseException] { - parser.parsePlan( - """CREATE TABLE ctas2 - |STORED AS - |INPUTFORMAT "org.apache.hadoop.mapred.TextInputFormat" - |OUTPUTFORMAT "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat" - |INPUTDRIVER "org.apache.hadoop.hive.howl.rcfile.RCFileInputDriver" - |OUTPUTDRIVER "org.apache.hadoop.hive.howl.rcfile.RCFileOutputDriver" - |AS SELECT key, value FROM src ORDER BY key, value - """.stripMargin) - } intercept[ParseException] { parser.parsePlan( """ @@ -324,6 +321,194 @@ class HiveDDLCommandSuite extends PlanTest { """.stripMargin) } + test("create table - basic") { + val query = "CREATE TABLE my_table (id int, name string)" + val (desc, allowExisting) = extractTableDesc(query) + assert(!allowExisting) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.schema == Seq(CatalogColumn("id", "int"), CatalogColumn("name", "string"))) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.sortColumnNames.isEmpty) + assert(desc.bucketColumnNames.isEmpty) + assert(desc.numBuckets == -1) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.locationUri.isEmpty) + assert(desc.storage.inputFormat == + Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde.isEmpty) + assert(desc.storage.serdeProperties.isEmpty) + assert(desc.properties.isEmpty) + assert(desc.comment.isEmpty) + } + + test("create table - with database name") { + val query = "CREATE TABLE dbx.my_table (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + } + + test("create table - temporary") { + val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" + val e = intercept[ParseException] { parser.parsePlan(query) } + assert(e.message.contains("registerTempTable")) + } + + test("create table - external") { + val query = "CREATE EXTERNAL TABLE tab1 (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + } + + test("create table - if not exists") { + val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" + val (_, allowExisting) = extractTableDesc(query) + assert(allowExisting) + } + + test("create table - comment") { + val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" + val (desc, _) = extractTableDesc(query) + assert(desc.comment == Some("its hot as hell below")) + } + + test("create table - partitioned columns") { + val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" + val (desc, _) = extractTableDesc(query) + assert(desc.schema == Seq( + CatalogColumn("id", "int"), + CatalogColumn("name", "string"), + CatalogColumn("month", "int"))) + assert(desc.partitionColumnNames == Seq("month")) + } + + test("create table - clustered by") { + val baseQuery = "CREATE TABLE my_table (id int, name 
string) CLUSTERED BY(id)" + val query1 = s"$baseQuery INTO 10 BUCKETS" + val query2 = s"$baseQuery SORTED BY(id) INTO 10 BUCKETS" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - skewed by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + val e3 = intercept[ParseException] { parser.parsePlan(query3) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + assert(e3.getMessage.contains("Operation not allowed")) + } + + test("create table - row format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" + val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" + val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" + val query3 = + s""" + |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' + |COLLECTION ITEMS TERMINATED BY 'a' + |MAP KEYS TERMINATED BY 'b' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'c' + """.stripMargin + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + val (desc3, _) = extractTableDesc(query3) + assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc1.storage.serdeProperties.isEmpty) + assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc2.storage.serdeProperties == Map("k1" -> "v1")) + assert(desc3.storage.serdeProperties == Map( + "field.delim" -> "x", + "escape.delim" -> "y", + "serialization.format" -> "x", + "line.delim" -> "\n", + "colelction.delim" -> "a", // yes, it's a typo from Hive :) + "mapkey.delim" -> "b")) + } + + test("create table - file format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" + val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" + val query2 = s"$baseQuery ORC" + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + assert(desc1.storage.inputFormat == Some("winput")) + assert(desc1.storage.outputFormat == Some("wowput")) + assert(desc1.storage.serde.isEmpty) + assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + test("create table - storage handler") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - location") { + val query = "CREATE TABLE my_table (id int, name 
string) LOCATION '/path/to/mars'" + val (desc, _) = extractTableDesc(query) + assert(desc.storage.locationUri == Some("/path/to/mars")) + } + + test("create table - properties") { + val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" + val (desc, _) = extractTableDesc(query) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + } + + test("create table - everything!") { + val query = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) + |COMMENT 'no comment' + |PARTITIONED BY (month int) + |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') + |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' + |LOCATION '/path/to/mercury' + |TBLPROPERTIES ('k1'='v1', 'k2'='v2') + """.stripMargin + val (desc, allowExisting) = extractTableDesc(query) + assert(allowExisting) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.schema == Seq( + CatalogColumn("id", "int"), + CatalogColumn("name", "string"), + CatalogColumn("month", "int"))) + assert(desc.partitionColumnNames == Seq("month")) + assert(desc.sortColumnNames.isEmpty) + assert(desc.bucketColumnNames.isEmpty) + assert(desc.numBuckets == -1) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.locationUri == Some("/path/to/mercury")) + assert(desc.storage.inputFormat == Some("winput")) + assert(desc.storage.outputFormat == Some("wowput")) + assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc.storage.serdeProperties == Map("k1" -> "v1")) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + assert(desc.comment == Some("no comment")) + } + test("create view -- basic") { val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" val (desc, exists) = extractTableDesc(v1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ada8621d07..8648834f0d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -88,7 +88,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.partitionColumns.isEmpty) + assert(hiveTable.partitionColumnNames.isEmpty) assert(hiveTable.tableType === CatalogTableType.MANAGED_TABLE) val columns = hiveTable.schema @@ -151,7 +151,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.partitionColumns.isEmpty) + assert(hiveTable.partitionColumnNames.isEmpty) assert(hiveTable.tableType === CatalogTableType.EXTERNAL_TABLE) val columns = hiveTable.schema diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 40e9c9362c..4db95636e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -81,7 +81,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef test("Double create 
fails when allowExisting = false") { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - intercept[QueryExecutionException] { + intercept[AnalysisException] { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 37c01792d9..97cb9d9720 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -149,7 +149,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) - val partValues = if (relation.table.partitionColumns.nonEmpty) { + val partValues = if (relation.table.partitionColumnNames.nonEmpty) { p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) } else { Seq.empty diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 7eaf19dfe9..5ce16be4dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -360,7 +360,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { var message = intercept[AnalysisException] { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") }.getMessage - assert(message.contains("ctas1 already exists")) + assert(message.contains("already exists")) checkRelation("ctas1", true) sql("DROP TABLE ctas1") -- cgit v1.2.3 From f9d578eaa107d8e8503c1563a2b3990c85104298 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 13 Apr 2016 11:31:10 -0700 Subject: [SPARK-13783][ML] Model export/import for spark.ml: GBTs ## What changes were proposed in this pull request? * Added save/load for ```GBTClassifier/GBTClassificationModel/GBTRegressor/GBTRegressionModel```. * Meanwhile, I modified ```EnsembleModelReadWrite.saveImpl/loadImpl``` to support save/load ```treeWeights```. ## How was this patch tested? Adds standard unit tests for GBT save/load. cc jkbradley GayathriMurali Author: Yanbo Liang Closes #12230 from yanboliang/spark-13783. 
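Editor's sketch, not part of the original commit message: minimal save/load usage of the API added here. The path and the training DataFrame (`trainingData`, with the usual "label"/"features" columns) are hypothetical.

```scala
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}

val gbt = new GBTClassifier().setMaxDepth(3).setMaxIter(10)
val model = gbt.fit(trainingData)
model.write.overwrite().save("/tmp/spark-gbtc-model")
val loaded = GBTClassificationModel.load("/tmp/spark-gbtc-model")
assert(loaded.treeWeights.sameElements(model.treeWeights))
```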
--- .../spark/ml/classification/GBTClassifier.scala | 110 ++++++++++++-------- .../ml/classification/RandomForestClassifier.scala | 2 +- .../apache/spark/ml/regression/GBTRegressor.scala | 114 +++++++++++++-------- .../ml/regression/RandomForestRegressor.scala | 2 +- .../org/apache/spark/ml/tree/treeModels.scala | 25 +++-- .../org/apache/spark/ml/tree/treeParams.scala | 73 ++++++++++++- .../ml/classification/GBTClassifierSuite.scala | 37 +++---- .../spark/ml/regression/GBTRegressorSuite.scala | 36 +++---- 8 files changed, 262 insertions(+), 137 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 46e8b89d01..39a698af15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -18,19 +18,21 @@ package org.apache.spark.ml.classification import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} +import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +60,7 @@ import org.apache.spark.sql.functions._ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTClassifier, GBTClassificationModel] - with GBTParams with TreeClassifierParams with Logging { + with GBTClassifierParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtc")) @@ -115,40 +117,12 @@ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTClassifier: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "logistic" - * (default = logistic) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). 
Supported options:" + - s" ${GBTClassifier.supportedLossTypes.mkString(", ")}", - (value: String) => GBTClassifier.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "logistic") + // Parameters from GBTClassifierParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "logistic" => OldLogLoss - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") - } - } - override protected def train(dataset: Dataset[_]): GBTClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -175,11 +149,14 @@ final class GBTClassifier @Since("1.4.0") ( @Since("1.4.0") @Experimental -object GBTClassifier { - // The losses below should be lowercase. +object GBTClassifier extends DefaultParamsReadable[GBTClassifier] { + /** Accessor for supported loss settings: logistic */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTClassifier = super.load(path) } /** @@ -199,7 +176,8 @@ final class GBTClassificationModel private[ml]( private val _treeWeights: Array[Double], @Since("1.6.0") override val numFeatures: Int) extends PredictionModel[Vector, GBTClassificationModel] - with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { + with GBTClassifierParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "GBTClassificationModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" + @@ -267,12 +245,62 @@ final class GBTClassificationModel private[ml]( private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this) } -private[ml] object GBTClassificationModel { +@Since("2.0.0") +object GBTClassificationModel extends MLReadable[GBTClassificationModel] { + + @Since("2.0.0") + override def read: MLReader[GBTClassificationModel] = new GBTClassificationModelReader + + @Since("2.0.0") + override def load(path: String): GBTClassificationModel = super.load(path) + + private[GBTClassificationModel] + class GBTClassificationModelWriter(instance: GBTClassificationModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTClassificationModelReader extends MLReader[GBTClassificationModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTClassificationModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTClassificationModel = { + implicit val format = DefaultFormats + val (metadata: 
Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + require(numTrees == trees.length, s"GBTClassificationModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + val model = new GBTClassificationModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTClassifier, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 9d80b8eb68..dfa711b243 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -294,7 +294,7 @@ object RandomForestClassificationModel extends MLReadable[RandomForestClassifica override def load(path: String): RandomForestClassificationModel = { implicit val format = DefaultFormats - val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + val (metadata: Metadata, treesData: Array[(Metadata, Node)], _) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numClasses = (metadata.metadata \ "numClasses").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 0b52fe2d13..741724d7a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -18,19 +18,20 @@ package org.apache.spark.ml.regression import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.json4s.{DefaultFormats, JObject} +import org.json4s.JsonDSL._ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml.{PredictionModel, Predictor} -import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.ml.tree.{GBTParams, TreeEnsembleModel, TreeRegressorParams} +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.GradientBoostedTrees -import org.apache.spark.ml.util.{Identifiable, MetadataUtils} +import org.apache.spark.ml.util._ +import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} -import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss => OldLoss, - SquaredError => OldSquaredError} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import 
org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset} @@ -58,7 +59,7 @@ import org.apache.spark.sql.functions._ @Experimental final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, GBTRegressor, GBTRegressionModel] - with GBTParams with TreeRegressorParams with Logging { + with GBTRegressorParams with DefaultParamsWritable with Logging { @Since("1.4.0") def this() = this(Identifiable.randomUID("gbtr")) @@ -112,41 +113,12 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") override def setStepSize(value: Double): this.type = super.setStepSize(value) - // Parameters for GBTRegressor: - - /** - * Loss function which GBT tries to minimize. (case-insensitive) - * Supported: "squared" (L2) and "absolute" (L1) - * (default = squared) - * @group param - */ - @Since("1.4.0") - val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + - " tries to minimize (case-insensitive). Supported options:" + - s" ${GBTRegressor.supportedLossTypes.mkString(", ")}", - (value: String) => GBTRegressor.supportedLossTypes.contains(value.toLowerCase)) - - setDefault(lossType -> "squared") + // Parameters from GBTRegressorParams: /** @group setParam */ @Since("1.4.0") def setLossType(value: String): this.type = set(lossType, value) - /** @group getParam */ - @Since("1.4.0") - def getLossType: String = $(lossType).toLowerCase - - /** (private[ml]) Convert new loss to old loss. */ - override private[ml] def getOldLossType: OldLoss = { - getLossType match { - case "squared" => OldSquaredError - case "absolute" => OldAbsoluteError - case _ => - // Should never happen because of check in setter method. - throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") - } - } - override protected def train(dataset: Dataset[_]): GBTRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -164,11 +136,14 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri @Since("1.4.0") @Experimental -object GBTRegressor { - // The losses below should be lowercase. 
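// A sketch, not part of the patch: the loss-type handling that moved into
// GBTRegressorParams, seen from user code. "squared" and "absolute" are the accepted
// values (matched case-insensitively); any other string is rejected by the Param
// validator when it is set.
import org.apache.spark.ml.regression.GBTRegressor
val sketchGbt = new GBTRegressor().setLossType("Absolute")
assert(sketchGbt.getLossType == "absolute")  // getLossType lower-cases the stored value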
+object GBTRegressor extends DefaultParamsReadable[GBTRegressor] { + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ @Since("1.4.0") - final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) + final val supportedLossTypes: Array[String] = GBTRegressorParams.supportedLossTypes + + @Since("2.0.0") + override def load(path: String): GBTRegressor = super.load(path) } /** @@ -188,7 +163,8 @@ final class GBTRegressionModel private[ml]( private val _treeWeights: Array[Double], override val numFeatures: Int) extends PredictionModel[Vector, GBTRegressionModel] - with TreeEnsembleModel[DecisionTreeRegressionModel] with Serializable { + with GBTRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with MLWritable with Serializable { require(_trees.nonEmpty, "GBTRegressionModel requires at least 1 tree.") require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" + @@ -255,12 +231,64 @@ final class GBTRegressionModel private[ml]( private[ml] def toOld: OldGBTModel = { new OldGBTModel(OldAlgo.Regression, _trees.map(_.toOld), _treeWeights) } + + @Since("2.0.0") + override def write: MLWriter = new GBTRegressionModel.GBTRegressionModelWriter(this) } -private[ml] object GBTRegressionModel { +@Since("2.0.0") +object GBTRegressionModel extends MLReadable[GBTRegressionModel] { + + @Since("2.0.0") + override def read: MLReader[GBTRegressionModel] = new GBTRegressionModelReader + + @Since("2.0.0") + override def load(path: String): GBTRegressionModel = super.load(path) + + private[GBTRegressionModel] + class GBTRegressionModelWriter(instance: GBTRegressionModel) extends MLWriter { + + override protected def saveImpl(path: String): Unit = { + val extraMetadata: JObject = Map( + "numFeatures" -> instance.numFeatures, + "numTrees" -> instance.getNumTrees) + EnsembleModelReadWrite.saveImpl(instance, path, sqlContext, extraMetadata) + } + } + + private class GBTRegressionModelReader extends MLReader[GBTRegressionModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[GBTRegressionModel].getName + private val treeClassName = classOf[DecisionTreeRegressionModel].getName + + override def load(path: String): GBTRegressionModel = { + implicit val format = DefaultFormats + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = + EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) + + val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] + val numTrees = (metadata.metadata \ "numTrees").extract[Int] + + val trees: Array[DecisionTreeRegressionModel] = treesData.map { + case (treeMetadata, root) => + val tree = + new DecisionTreeRegressionModel(treeMetadata.uid, root, numFeatures) + DefaultParamsReader.getAndSetParams(tree, treeMetadata) + tree + } + + require(numTrees == trees.length, s"GBTRegressionModel.load expected $numTrees" + + s" trees based on metadata but found ${trees.length} trees.") + + val model = new GBTRegressionModel(metadata.uid, trees, treeWeights, numFeatures) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } - /** (private[ml]) Convert a model from the old API */ - def fromOld( + /** Convert a model from the old API */ + private[ml] def fromOld( oldModel: OldGBTModel, parent: GBTRegressor, categoricalFeatures: Map[Int, Int], diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index bee13c2ebf..4c4ff278d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -249,7 +249,7 @@ object RandomForestRegressionModel extends MLReadable[RandomForestRegressionMode override def load(path: String): RandomForestRegressionModel = { implicit val format = DefaultFormats - val (metadata: Metadata, treesData: Array[(Metadata, Node)]) = + val (metadata: Metadata, treesData: Array[(Metadata, Node)], treeWeights: Array[Double]) = EnsembleModelReadWrite.loadImpl(path, sqlContext, className, treeClassName) val numFeatures = (metadata.metadata \ "numFeatures").extract[Int] val numTrees = (metadata.metadata \ "numTrees").extract[Int] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index c4ab673d9a..f38e1ec7c0 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -396,12 +396,14 @@ private[ml] object EnsembleModelReadWrite { sql: SQLContext, extraMetadata: JObject): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sql.sparkContext, Some(extraMetadata)) - val treesMetadataJson: Array[(Int, String)] = instance.trees.zipWithIndex.map { + val treesMetadataWeights: Array[(Int, String, Double)] = instance.trees.zipWithIndex.map { case (tree, treeID) => - treeID -> DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext) + (treeID, + DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sql.sparkContext), + instance.treeWeights(treeID)) } val treesMetadataPath = new Path(path, "treesMetadata").toString - sql.createDataFrame(treesMetadataJson).toDF("treeID", "metadata") + sql.createDataFrame(treesMetadataWeights).toDF("treeID", "metadata", "weights") .write.parquet(treesMetadataPath) val dataPath = new Path(path, "data").toString val nodeDataRDD = sql.sparkContext.parallelize(instance.trees.zipWithIndex).flatMap { @@ -424,7 +426,7 @@ private[ml] object EnsembleModelReadWrite { path: String, sql: SQLContext, className: String, - treeClassName: String): (Metadata, Array[(Metadata, Node)]) = { + treeClassName: String): (Metadata, Array[(Metadata, Node)], Array[Double]) = { import sql.implicits._ implicit val format = DefaultFormats val metadata = DefaultParamsReader.loadMetadata(path, sql.sparkContext, className) @@ -436,12 +438,15 @@ private[ml] object EnsembleModelReadWrite { } val treesMetadataPath = new Path(path, "treesMetadata").toString - val treesMetadataRDD: RDD[(Int, Metadata)] = sql.read.parquet(treesMetadataPath) - .select("treeID", "metadata").as[(Int, String)].rdd.map { - case (treeID: Int, json: String) => - treeID -> DefaultParamsReader.parseMetadata(json, treeClassName) + val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath) + .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map { + case (treeID: Int, json: String, weights: Double) => + treeID -> (DefaultParamsReader.parseMetadata(json, treeClassName), weights) } - val treesMetadata: Array[Metadata] = treesMetadataRDD.sortByKey().values.collect() + + val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect() + val treesMetadata = treesMetadataWeights.map(_._1) + val treesWeights = treesMetadataWeights.map(_._2) val dataPath = 
new Path(path, "data").toString val nodeData: Dataset[EnsembleNodeData] = @@ -452,7 +457,7 @@ private[ml] object EnsembleModelReadWrite { treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType) } val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect() - (metadata, treesMetadata.zip(rootNodes)) + (metadata, treesMetadata.zip(rootNodes), treesWeights) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 0767dc17e5..b6783911ad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} -import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, LogLoss => OldLogLoss, Loss => OldLoss, SquaredError => OldSquaredError} import org.apache.spark.sql.types.{DataType, DoubleType, StructType} /** @@ -462,3 +462,74 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS /** Get old Gradient Boosting Loss type */ private[ml] def getOldLossType: OldLoss } + +private[ml] object GBTClassifierParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: logistic */ + final val supportedLossTypes: Array[String] = Array("logistic").map(_.toLowerCase) +} + +private[ml] trait GBTClassifierParams extends GBTParams with TreeClassifierParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "logistic" + * (default = logistic) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). Supported options:" + + s" ${GBTClassifierParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTClassifierParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "logistic") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "logistic" => OldLogLoss + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTClassifier was given bad loss type: $getLossType") + } + } +} + +private[ml] object GBTRegressorParams { + // The losses below should be lowercase. + /** Accessor for supported loss settings: squared (L2), absolute (L1) */ + final val supportedLossTypes: Array[String] = Array("squared", "absolute").map(_.toLowerCase) +} + +private[ml] trait GBTRegressorParams extends GBTParams with TreeRegressorParams { + + /** + * Loss function which GBT tries to minimize. (case-insensitive) + * Supported: "squared" (L2) and "absolute" (L1) + * (default = squared) + * @group param + */ + val lossType: Param[String] = new Param[String](this, "lossType", "Loss function which GBT" + + " tries to minimize (case-insensitive). 
Supported options:" + + s" ${GBTRegressorParams.supportedLossTypes.mkString(", ")}", + (value: String) => GBTRegressorParams.supportedLossTypes.contains(value.toLowerCase)) + + setDefault(lossType -> "squared") + + /** @group getParam */ + def getLossType: String = $(lossType).toLowerCase + + /** (private[ml]) Convert new loss to old loss. */ + override private[ml] def getOldLossType: OldLoss = { + getLossType match { + case "squared" => OldSquaredError + case "absolute" => OldAbsoluteError + case _ => + // Should never happen because of check in setter method. + throw new RuntimeException(s"GBTRegressorParams was given bad loss type: $getLossType") + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 76d8c9372e..7e6aec6b1b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree.LeafNode import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTClassifier]]. */ -class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTClassifierSuite.compareAPIs @@ -156,27 +157,23 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Classification, trees, treeWeights) - val newModel = GBTClassificationModel.fromOld(oldModel) - - // Save model, load it back, and compare. 
- try { - newModel.save(sc, path) - val sameNewModel = GBTClassificationModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTClassificationModel, + model2: GBTClassificationModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTClassifier() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "logistic") + + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTClassifierSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 3c11631f98..216377959e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.tree.impl.TreeTests -import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} @@ -32,7 +32,8 @@ import org.apache.spark.util.Utils /** * Test suite for [[GBTRegressor]]. */ -class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { +class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext + with DefaultReadWriteTest { import GBTRegressorSuite.compareAPIs @@ -164,27 +165,22 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// - // TODO: Reinstate test once save/load are implemented SPARK-6725 - /* test("model save/load") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - - val trees = Range(0, 3).map(_ => OldDecisionTreeSuite.createModel(OldAlgo.Regression)).toArray - val treeWeights = Array(0.1, 0.3, 1.1) - val oldModel = new OldGBTModel(OldAlgo.Regression, trees, treeWeights) - val newModel = GBTRegressionModel.fromOld(oldModel) - - // Save model, load it back, and compare. 
- try { - newModel.save(sc, path) - val sameNewModel = GBTRegressionModel.load(sc, path) - TreeTests.checkEqual(newModel, sameNewModel) - } finally { - Utils.deleteRecursively(tempDir) + def checkModelData( + model: GBTRegressionModel, + model2: GBTRegressionModel): Unit = { + TreeTests.checkEqual(model, model2) + assert(model.numFeatures === model2.numFeatures) } + + val gbt = new GBTRegressor() + val rdd = TreeTests.getTreeReadWriteData(sc) + + val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") + val continuousData: DataFrame = + TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) } - */ } private object GBTRegressorSuite extends SparkFunSuite { -- cgit v1.2.3 From dbbe149070052af5cda04f7b110d65de73766ded Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 13 Apr 2016 13:01:13 -0700 Subject: [SPARK-14581] [SQL] push predicatese through more logical plans ## What changes were proposed in this pull request? Right now, filter push down only works with Project, Aggregate, Generate and Join, they can't be pushed through many other plans. This PR added support for Union, Intersect, Except and all unary plans. ## How was this patch tested? Added tests. Author: Davies Liu Closes #12342 from davies/filter_hint. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 111 +++++++++++++-------- .../spark/sql/catalyst/planning/patterns.scala | 3 + .../catalyst/optimizer/ColumnPruningSuite.scala | 2 +- .../catalyst/optimizer/FilterPushdownSuite.scala | 76 ++++++++++++-- .../catalyst/optimizer/JoinOptimizationSuite.scala | 4 +- .../sql/catalyst/optimizer/PruneFiltersSuite.scala | 2 +- 6 files changed, 146 insertions(+), 52 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bad115d22f..438cbabdbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ReorderJoin, OuterJoinElimination, PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, + PushDownPredicate, LimitPushDown, ColumnPruning, InferFiltersFromConstraints, @@ -917,12 +915,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { } /** - * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]] - * that were defined in the projection. + * Pushes [[Filter]] operators through many operators iff: + * 1) the operator is deterministic + * 2) the predicate is deterministic and the operator will not change any of rows. * * This heuristic is valid assuming the expression evaluation cost is minimal. */ -object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { +object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // SPARK-13473: We can't push the predicate down when the underlying projection output non- // deterministic field(s). Non-deterministic expressions are essentially stateful. 
This @@ -939,41 +938,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe }) project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) - } - -} - -/** - * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference - * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath. - */ -object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, g: Generate) => - // Predicates that reference attributes produced by the `Generate` operator cannot - // be pushed below the operator. - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => - cond.references.subsetOf(g.child.outputSet) && cond.deterministic - } - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, - g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) - } else { - filter - } - } -} -/** - * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only - * non-aggregate attributes (typically literals or grouping expressions). - */ -object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, aggregate: Aggregate) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression @@ -999,6 +964,72 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } else { filter } + + case filter @ Filter(condition, child) + if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] => + // Union/Intersect could change the rows, so non-deterministic predicate can't be pushed down + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.deterministic + } + if (pushDown.nonEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = child.output + val newGrandChildren = child.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet)) + Filter(newCond, grandchild) + } + val newChild = child.withNewChildren(newGrandChildren) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter + } + + case filter @ Filter(condition, e @ Except(left, _)) => + pushDownPredicate(filter, e.left) { predicate => + e.copy(left = Filter(predicate, left)) + } + + // two filters should be combine together by other rules + case filter @ Filter(_, f: Filter) => filter + // should not push predicates through sample, or will generate different results. 
+ case filter @ Filter(_, s: Sample) => filter + // TODO: push predicates through expand + case filter @ Filter(_, e: Expand) => filter + + case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => + pushDownPredicate(filter, u.child) { predicate => + u.withNewChildren(Seq(Filter(predicate, u.child))) + } + } + + private def pushDownPredicate( + filter: Filter, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + // Only push down the predicates that is deterministic and all the referenced attributes + // come from grandchild. + // TODO: non-deterministic predicates could be pushed through some operators that do not change + // the rows. + val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond => + cond.deterministic && cond.references.subsetOf(grandchild.outputSet) + } + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6f35d87ebb..0065619135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -69,6 +69,9 @@ object PhysicalOperation extends PredicateHelper { val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + case BroadcastHint(child) => + collectProjectsAndFilters(child) + case other => (None, Nil, other, Map.empty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 2248e03b2f..52b574c0e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -34,7 +34,7 @@ class ColumnPruningSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Column pruning", FixedPoint(100), - PushPredicateThroughProject, + PushDownPredicate, ColumnPruning, CollapseProject) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b84ae7c5bb..df7529d83f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: - Batch("Filter Pushdown", Once, + Batch("Filter Pushdown", FixedPoint(10), SamplePushDown, CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, CollapseProject) :: Nil } @@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a 
=== 3) + .select('a, 'b) .groupBy('a)('a, count('b) as 'c) .where('c === 2L) .analyze @@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a + 1 < 3) + .select('a, 'b) .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) .where('c === 2L || 'aa > 4) .analyze @@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where("s" === "s") + .select('a, 'b) .groupBy('a)('a, count('b) as 'c, "s" as 'd) .where('c === 2L) .analyze @@ -681,4 +679,68 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("broadcast hint") { + val originalQuery = BroadcastHint(testRelation) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = BroadcastHint(testRelation.where('a === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("union") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Union(Seq(testRelation, testRelation2)) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Union(Seq( + testRelation.where('a === 2L), + testRelation2.where('d === 2L))) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("intersect") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Intersect(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Intersect( + testRelation.where('a === 2L), + testRelation2.where('d === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("except") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Except(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Except( + testRelation.where('a === 2L), + testRelation2) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index e2f8146bee..c1ebf8b09e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -36,12 +36,10 @@ class JoinOptimizationSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown", FixedPoint(100), CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, ReorderJoin, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 14fb72a8a3..d8cfec5391 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest { Batch("Filter Pushdown and Pruning", Once, CombineFilters, PruneFilters, - PushPredicateThroughProject, + PushDownPredicate, PushPredicateThroughJoin) :: Nil } -- cgit v1.2.3 From b0adb9f543fbac16ea14c64eef6ba032a9919039 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 13 Apr 2016 13:18:02 -0700 Subject: [SPARK-10386][MLLIB] PrefixSpanModel supports save/load ```PrefixSpanModel``` supports ```save/load```. It's similar with #9267. cc jkbradley Author: Yanbo Liang Closes #10664 from yanboliang/spark-10386. --- .../org/apache/spark/mllib/fpm/PrefixSpan.scala | 96 +++++++++++++++++++++- .../spark/mllib/fpm/JavaPrefixSpanSuite.java | 37 +++++++++ .../apache/spark/mllib/fpm/PrefixSpanSuite.scala | 31 +++++++ 3 files changed, 163 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 4455681e50..4344ab1bad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.internal.Logging +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel /** @@ -566,4 +576,88 @@ object PrefixSpan extends Logging { @Since("1.5.0") class PrefixSpanModel[Item] @Since("1.5.0") ( @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) - extends Serializable + extends Saveable with Serializable { + + /** + * Save this model to the given path. + * It only works for Item datatypes supported by DataFrames. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[PrefixSpanModel.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. 
+ */ + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + PrefixSpanModel.SaveLoadV1_0.save(this, path) + } + + override protected val formatVersion: String = "1.0" +} + +@Since("2.0.0") +object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { + PrefixSpanModel.SaveLoadV1_0.load(sc, path) + } + + private[fpm] object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel" + + def save(model: PrefixSpanModel[_], path: String): Unit = { + val sc = model.freqSequences.sparkContext + val sqlContext = SQLContext.getOrCreate(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Get the type of item class + val sample = model.freqSequences.first().sequence(0)(0) + val className = sample.getClass.getCanonicalName + val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className) + val tpe = classSymbol.selfType + + val itemType = ScalaReflection.schemaFor(tpe).dataType + val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))), + StructField("freq", LongType)) + val schema = StructType(fields) + val rowDataRDD = model.freqSequences.map { x => + Row(x.sequence, x.freq) + } + sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { + implicit val formats = DefaultFormats + val sqlContext = SQLContext.getOrCreate(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val freqSequences = sqlContext.read.parquet(Loader.dataPath(path)) + val sample = freqSequences.select("sequence").head().get(0) + loadImpl(freqSequences, sample) + } + + def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = { + val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x => + val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray + val freq = x.getLong(1) + new PrefixSpan.FreqSequence(sequence, freq) + } + new PrefixSpanModel(freqSequencesRDD) + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java index 34daf5fbde..8a67793abc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.fpm; +import java.io.File; import java.util.Arrays; import java.util.List; @@ -28,6 +29,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; +import org.apache.spark.util.Utils; public class JavaPrefixSpanSuite { private transient JavaSparkContext sc; @@ -64,4 +66,39 @@ public class JavaPrefixSpanSuite { long freq = freqSeq.freq(); } } + + @Test + public void runPrefixSpanSaveLoad() { + JavaRDD>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), 
Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) + ), 2); + PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); + PrefixSpanModel model = prefixSpan.run(sequences); + + File tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite"); + String outputPath = tempDir.getPath(); + + try { + model.save(sc.sc(), outputPath); + PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath); + JavaRDD> freqSeqs = newModel.freqSequences().toJavaRDD(); + List> localFreqSeqs = freqSeqs.collect(); + Assert.assertEquals(5, localFreqSeqs.size()); + // Check that each frequent sequence could be materialized. + for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + List> seq = freqSeq.javaSequence(); + long freq = freqSeq.freq(); + } + } finally { + Utils.deleteRecursively(tempDir); + } + + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index a83e543859..6d8c7b47d8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("model save/load") { + val sequences = Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6))) + val rdd = sc.parallelize(sequences, 2).cache() + + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + val model = prefixSpan.run(rdd) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model.save(sc, path) + val newModel = PrefixSpanModel.load(sc, path) + val originalSet = model.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + val newSet = newModel.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + assert(originalSet === newSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + private def compareResults[Item]( expectedValue: Array[(Array[Array[Item]], Long)], actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = { -- cgit v1.2.3 From 0d17593b32c12c3e39575430aa85cf20e56fae6a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 13 Apr 2016 13:20:29 -0700 Subject: [SPARK-14461][ML] GLM training summaries should provide solver ## What changes were proposed in this pull request? GLM training summaries should provide solver. ## How was this patch tested? Unit tests. cc jkbradley Author: Yanbo Liang Closes #12253 from yanboliang/spark-14461. 
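As a rough illustration of how the new field surfaces to users — a minimal sketch, not code from this patch; it assumes a Spark 2.0-style shell where `sqlContext` is already available and uses the bundled `data/mllib/sample_linear_regression_data.txt` dataset:

```scala
import org.apache.spark.ml.regression.GeneralizedLinearRegression

// Load a small LIBSVM dataset shipped with Spark.
val dataset = sqlContext.read.format("libsvm")
  .load("data/mllib/sample_linear_regression_data.txt")

// Fit a Gaussian GLM with the default IRLS solver.
val glr = new GeneralizedLinearRegression()
  .setFamily("gaussian")
  .setLink("identity")
  .setMaxIter(10)
val model = glr.fit(dataset)

// With this change the training summary also reports which solver produced the model.
val summary = model.summary
println(s"solver: ${summary.solver}")            // "irls"
println(s"iterations: ${summary.numIterations}")
```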
--- .../spark/ml/regression/GeneralizedLinearRegression.scala | 10 +++++++--- .../spark/ml/regression/GeneralizedLinearRegressionSuite.scala | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 00cf25dc54..e92a3e7fa1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, wlsModel.diagInvAtWA.toArray, - 1) + 1, + getSolver) return model.setSummary(trainingSummary) } @@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, irlsModel.diagInvAtWA.toArray, - irlsModel.numIterations) + irlsModel.numIterations, + getSolver) model.setSummary(trainingSummary) } @@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr * @param model the model that should be summarized * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration * @param numIterations number of iterations + * @param solver the solver algorithm used for model training */ @Since("2.0.0") @Experimental @@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") val predictionCol: String, @Since("2.0.0") val model: GeneralizedLinearRegressionModel, private val diagInvAtWA: Array[Double], - @Since("2.0.0") val numIterations: Int) extends Serializable { + @Since("2.0.0") val numIterations: Int, + @Since("2.0.0") val solver: String) extends Serializable { import GeneralizedLinearRegression._ diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index 4905f3e068..3ecc210abd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: binomial family with weight") { @@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: poisson family with weight") { @@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR absTol 1E-3) + assert(summary.solver === "irls") } test("glm summary: gamma family with weight") { @@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR) assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR) assert(summary.aic ~== aicR 
absTol 1E-3) + assert(summary.solver === "irls") } test("read/write") { -- cgit v1.2.3 From a91aaf5a8cca18811c0cccc20f4e77f36231b344 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 13 Apr 2016 13:23:10 -0700 Subject: [SPARK-14375][ML] Unit test for spark.ml KMeansSummary ## What changes were proposed in this pull request? * Modify ```KMeansSummary.clusterSizes``` method to make it robust to empty clusters. * Add unit test for spark.ml ```KMeansSummary```. * Add Since tag. ## How was this patch tested? unit tests. cc jkbradley Author: Yanbo Liang Closes #12254 from yanboliang/spark-14375. --- .../org/apache/spark/ml/clustering/KMeans.scala | 35 ++++++++++++++++++---- .../org/apache/spark/ml/r/KMeansWrapper.scala | 2 +- .../apache/spark/ml/clustering/KMeansSuite.scala | 18 ++++++++++- 3 files changed, 47 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d716bc6887..b324196842 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -143,6 +143,12 @@ class KMeansModel private[ml] ( this } + /** + * Return true if there exists summary of model. + */ + @Since("2.0.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + /** * Gets summary of model on training set. An exception is * thrown if `trainingSummary == None`. @@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) - val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol)) + val summary = new KMeansSummary( + model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(summary) } @@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] { override def load(path: String): KMeans = super.load(path) } +/** + * :: Experimental :: + * Summary of KMeans. + * + * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.0.0") +@Experimental class KMeansSummary private[clustering] ( @Since("2.0.0") @transient val predictions: DataFrame, @Since("2.0.0") val predictionCol: String, - @Since("2.0.0") val featuresCol: String) extends Serializable { + @Since("2.0.0") val featuresCol: String, + @Since("2.0.0") val k: Int) extends Serializable { /** * Cluster centers of the transformed data. @@ -296,11 +315,15 @@ class KMeansSummary private[clustering] ( @transient lazy val cluster: DataFrame = predictions.select(predictionCol) /** - * Size of each cluster. + * Size of (number of data points in) each cluster. 
*/ @Since("2.0.0") - lazy val clusterSizes: Array[Int] = cluster.rdd.map { - case Row(clusterIdx: Int) => (clusterIdx, 1) - }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index ee513579ce..9e2b81ee20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -37,7 +37,7 @@ private[r] class KMeansWrapper private ( lazy val k: Int = kMeansModel.getK - lazy val size: Array[Int] = kMeansModel.summary.clusterSizes + lazy val size: Array[Long] = kMeansModel.summary.clusterSizes lazy val cluster: DataFrame = kMeansModel.summary.cluster diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2076c745e2..2ca386e422 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } } - test("fit & transform") { + test("fit, transform, and summary") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) @@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: KMeansSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) } test("read/write") { -- cgit v1.2.3 From fcdd69260ec75c180f4d727ff2625ca9bf0bdad7 Mon Sep 17 00:00:00 2001 From: Zheng RuiFeng Date: Wed, 13 Apr 2016 13:56:23 -0700 Subject: [SPARK-14509][DOC] Add python CountVectorizerExample ## What changes were proposed in this pull request? Add python CountVectorizerExample ## How was this patch tested? manual tests Author: Zheng RuiFeng Closes #11917 from zhengruifeng/cv_pe. --- docs/ml-features.md | 9 +++++ .../src/main/python/ml/count_vectorizer_example.py | 44 ++++++++++++++++++++++ 2 files changed, 53 insertions(+) create mode 100644 examples/src/main/python/ml/count_vectorizer_example.py diff --git a/docs/ml-features.md b/docs/ml-features.md index 5cc27d3565..70812eb5e2 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -149,6 +149,15 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaCountVectorizerExample.java %}
+
+<div data-lang="python" markdown="1">
+
+Refer to the [CountVectorizer Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizer)
+and the [CountVectorizerModel Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.CountVectorizerModel)
+for more details on the API.
+
+{% include_example python/ml/count_vectorizer_example.py %}
+</div>
    # Feature Transformers diff --git a/examples/src/main/python/ml/count_vectorizer_example.py b/examples/src/main/python/ml/count_vectorizer_example.py new file mode 100644 index 0000000000..e839f645f7 --- /dev/null +++ b/examples/src/main/python/ml/count_vectorizer_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import CountVectorizer +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="CountVectorizerExample") + sqlContext = SQLContext(sc) + + # $example on$ + # Input data: Each row is a bag of words with a ID. + df = sqlContext.createDataFrame([ + (0, "a b c".split(" ")), + (1, "a b b c a".split(" ")) + ], ["id", "words"]) + + # fit a CountVectorizerModel from the corpus. + cv = CountVectorizer(inputCol="words", outputCol="features", vocabSize=3, minDF=2.0) + model = cv.fit(df) + result = model.transform(df) + result.show() + # $example off$ + + sc.stop() -- cgit v1.2.3 From 781df499836e4216939e0febdcd5f89d30645759 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Wed, 13 Apr 2016 13:58:35 -0700 Subject: [SPARK-13089][ML] [Doc] spark.ml Naive Bayes user guide and examples jira: https://issues.apache.org/jira/browse/SPARK-13089 Add section in ml-classification.md for NaiveBayes DataFrame-based API, plus example code (using include_example to clip code from examples/ folder files). Author: Yuhao Yang Closes #11015 from hhbyyh/naiveBayesDoc. --- docs/ml-classification-regression.md | 34 ++++++++++++ .../spark/examples/ml/JavaNaiveBayesExample.java | 64 ++++++++++++++++++++++ examples/src/main/python/ml/naive_bayes_example.py | 53 ++++++++++++++++++ .../spark/examples/ml/NaiveBayesExample.scala | 58 ++++++++++++++++++++ 4 files changed, 209 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java create mode 100644 examples/src/main/python/ml/naive_bayes_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index 45155c8ad1..eaf4f6d843 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -302,6 +302,40 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/OneVsRe +## Naive Bayes + +[Naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier) are a family of simple +probabilistic classifiers based on applying Bayes' theorem with strong (naive) independence +assumptions between the features. 
The spark.ml implementation currently supports both [multinomial +naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html) +and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html). +More information can be found in the section on [Naive Bayes in MLlib](mllib-naive-bayes.html#naive-bayes-sparkmllib). + +**Example** + +
+<div data-lang="scala" markdown="1">
+
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes) for more details.
+
+{% include_example scala/org/apache/spark/examples/ml/NaiveBayesExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+
+Refer to the [Java API docs](api/java/org/apache/spark/ml/classification/NaiveBayes.html) for more details.
+
+{% include_example java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes) for more details.
+
+{% include_example python/ml/naive_bayes_example.py %}
+</div>
+</div>
    + # Regression diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java new file mode 100644 index 0000000000..41d7ad75b9 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaSparkContext; +// $example on$ +import org.apache.spark.ml.classification.NaiveBayes; +import org.apache.spark.ml.classification.NaiveBayesModel; +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; +// $example off$ + +/** + * An example for Naive Bayes Classification. + */ +public class JavaNaiveBayesExample { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaNaiveBayesExample"); + JavaSparkContext jsc = new JavaSparkContext(conf); + SQLContext jsql = new SQLContext(jsc); + + // $example on$ + // Load training data + Dataset dataFrame = jsql.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt"); + // Split the data into train and test + Dataset[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L); + Dataset train = splits[0]; + Dataset test = splits[1]; + + // create the trainer and set its parameters + NaiveBayes nb = new NaiveBayes(); + // train the model + NaiveBayesModel model = nb.fit(train); + // compute precision on the test set + Dataset result = model.transform(test); + Dataset predictionAndLabels = result.select("prediction", "label"); + MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator() + .setMetricName("precision"); + System.out.println("Precision = " + evaluator.evaluate(predictionAndLabels)); + // $example off$ + + jsc.stop(); + } +} diff --git a/examples/src/main/python/ml/naive_bayes_example.py b/examples/src/main/python/ml/naive_bayes_example.py new file mode 100644 index 0000000000..db8fbea9bf --- /dev/null +++ b/examples/src/main/python/ml/naive_bayes_example.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.classification import NaiveBayes +from pyspark.ml.evaluation import MulticlassClassificationEvaluator +# $example off$ + +if __name__ == "__main__": + + sc = SparkContext(appName="naive_bayes_example") + sqlContext = SQLContext(sc) + + # $example on$ + # Load training data + data = sqlContext.read.format("libsvm") \ + .load("data/mllib/sample_libsvm_data.txt") + # Split the data into train and test + splits = data.randomSplit([0.6, 0.4], 1234) + train = splits[0] + test = splits[1] + + # create the trainer and set its parameters + nb = NaiveBayes(smoothing=1.0, modelType="multinomial") + + # train the model + model = nb.fit(train) + # compute precision on the test set + result = model.transform(test) + predictionAndLabels = result.select("prediction", "label") + evaluator = MulticlassClassificationEvaluator(metricName="precision") + print("Precision:" + str(evaluator.evaluate(predictionAndLabels))) + # $example off$ + + sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala new file mode 100644 index 0000000000..5ea1270c97 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.classification.{NaiveBayes} +import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +// $example off$ +import org.apache.spark.sql.SQLContext + +object NaiveBayesExample { + def main(args: Array[String]): Unit = { + val conf = new SparkConf().setAppName("NaiveBayesExample") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + // $example on$ + // Load the data stored in LIBSVM format as a DataFrame. + val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") + + // Split the data into training and test sets (30% held out for testing) + val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3)) + + // Train a NaiveBayes model. 
+ val model = new NaiveBayes() + .fit(trainingData) + + // Select example rows to display. + val predictions = model.transform(testData) + predictions.show() + + // Select (prediction, true label) and compute test error + val evaluator = new MulticlassClassificationEvaluator() + .setLabelCol("label") + .setPredictionCol("prediction") + .setMetricName("precision") + val precision = evaluator.evaluate(predictions) + println("Precision:" + precision) + // $example off$ + } +} +// scalastyle:on println -- cgit v1.2.3 From fc3cd2f5090b3ba1cfde0fca3b3ce632d0b2f9c4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 13 Apr 2016 14:08:57 -0700 Subject: [SPARK-14472][PYSPARK][ML] Cleanup ML JavaWrapper and related class hierarchy Currently, JavaWrapper is only a wrapper class for pipeline classes that have Params and JavaCallable is a separate mixin that provides methods to make Java calls. This change simplifies the class structure and to define the Java wrapper in a plain base class along with methods to make Java calls. Also, renames Java wrapper classes to better reflect their purpose. Ran existing Python ml tests and generated documentation to test this change. Author: Bryan Cutler Closes #12304 from BryanCutler/pyspark-cleanup-JavaWrapper-SPARK-14472. --- python/pyspark/ml/classification.py | 4 +- python/pyspark/ml/evaluation.py | 4 +- python/pyspark/ml/pipeline.py | 10 ++--- python/pyspark/ml/regression.py | 4 +- python/pyspark/ml/tests.py | 4 +- python/pyspark/ml/tuning.py | 26 ++++++------- python/pyspark/ml/util.py | 4 +- python/pyspark/ml/wrapper.py | 76 +++++++++++++++++-------------------- 8 files changed, 62 insertions(+), 70 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index e64c7a392b..922f8069fa 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -19,7 +19,7 @@ import warnings from pyspark import since from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( @@ -272,7 +272,7 @@ class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): return BinaryLogisticRegressionSummary(java_blr_summary) -class LogisticRegressionSummary(JavaCallable): +class LogisticRegressionSummary(JavaWrapper): """ Abstraction for Logistic Regression Results for a given model. diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index c9b95b3bf4..4b0bade102 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -18,7 +18,7 @@ from abc import abstractmethod, ABCMeta from pyspark import since -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only @@ -81,7 +81,7 @@ class Evaluator(Params): @inherit_doc -class JavaEvaluator(Evaluator, JavaWrapper): +class JavaEvaluator(JavaParams, Evaluator): """ Base class for :py:class:`Evaluator`s that wrap Java/Scala implementations. 
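For reference, a stripped-down sketch of the class layering this refactor moves to — plain stand-ins with the JVM plumbing and the real Params machinery omitted; the exact bases of `JavaParams` are inferred from the renamed usages in this diff, so treat them as an assumption:

```python
# Illustrative stubs only, not the real pyspark.ml classes.

class JavaWrapper(object):
    """Plain wrapper holding a Java companion object (no Params mixed in)."""
    def __init__(self, java_obj=None):
        self._java_obj = java_obj  # the wrapped Java object, or None until created


class Params(object):
    """Stand-in for pyspark.ml.param.Params."""


class JavaParams(JavaWrapper, Params):
    """Params-carrying pipeline stages that delegate to a Java implementation."""


class Evaluator(Params):
    """Stand-in for pyspark.ml.evaluation.Evaluator."""


class JavaEvaluator(JavaParams, Evaluator):
    """Mirrors `class JavaEvaluator(JavaParams, Evaluator)` in the hunk above."""
```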
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 2b5504bc29..9d654e8b0f 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -25,7 +25,7 @@ from pyspark import since from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.mllib.common import inherit_doc @@ -177,7 +177,7 @@ class Pipeline(Estimator, MLReadable, MLWritable): # Create a new instance of this stage. py_stage = cls() # Load information from java_stage to the instance. - py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()] + py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()] py_stage.setStages(py_stages) py_stage._resetUid(java_stage.uid()) return py_stage @@ -195,7 +195,7 @@ class Pipeline(Estimator, MLReadable, MLWritable): for idx, stage in enumerate(self.getStages()): java_stages[idx] = stage._to_java() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) _java_obj.setStages(java_stages) return _java_obj @@ -275,7 +275,7 @@ class PipelineModel(Model, MLReadable, MLWritable): Used for ML persistence. """ # Load information from java_stage to the instance. - py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()] + py_stages = [JavaParams._from_java(s) for s in java_stage.stages()] # Create a new instance of this stage. py_stage = cls(py_stages) py_stage._resetUid(java_stage.uid()) @@ -295,6 +295,6 @@ class PipelineModel(Model, MLReadable, MLWritable): java_stages[idx] = stage._to_java() _java_obj =\ - JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) + JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) return _java_obj diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index bc88f88b7f..316d7e30bc 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -20,7 +20,7 @@ import warnings from pyspark import since from pyspark.ml.param.shared import * from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.mllib.common import inherit_doc from pyspark.sql import DataFrame @@ -188,7 +188,7 @@ class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): return LinearRegressionSummary(java_lr_summary) -class LinearRegressionSummary(JavaCallable): +class LinearRegressionSummary(JavaWrapper): """ .. 
note:: Experimental diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2dcd5eeb52..bcbeacbe80 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -52,7 +52,7 @@ from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.util import MLWritable, MLWriter -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -644,7 +644,7 @@ class PersistenceTest(PySparkTestCase): """ self.assertEqual(m1.uid, m2.uid) self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaWrapper): + if isinstance(m1, JavaParams): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index ea8c61b7ef..456d79d897 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml import Estimator, Model from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand from pyspark.mllib.common import inherit_doc, _py2java @@ -148,8 +148,8 @@ class ValidatorParams(HasSeed): """ # Load information from java_stage to the instance. - estimator = JavaWrapper._from_java(java_stage.getEstimator()) - evaluator = JavaWrapper._from_java(java_stage.getEvaluator()) + estimator = JavaParams._from_java(java_stage.getEstimator()) + evaluator = JavaParams._from_java(java_stage.getEvaluator()) epms = [estimator._transfer_param_map_from_java(epm) for epm in java_stage.getEstimatorParamMaps()] return estimator, epms, evaluator @@ -329,7 +329,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -393,7 +393,7 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): """ # Load information from java_stage to the instance. - bestModel = JavaWrapper._from_java(java_stage.bestModel()) + bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. 
py_stage = cls(bestModel=bestModel)\ @@ -410,10 +410,10 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable): sc = SparkContext._active_spark_context - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", - self.uid, - self.bestModel._to_java(), - _py2java(sc, [])) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() _java_obj.set("evaluator", evaluator) @@ -574,8 +574,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", - self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", + self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -639,7 +639,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): """ # Load information from java_stage to the instance. - bestModel = JavaWrapper._from_java(java_stage.bestModel()) + bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = \ super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. @@ -657,7 +657,7 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable): sc = SparkContext._active_spark_context - _java_obj = JavaWrapper._new_java_obj( + _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.tuning.TrainValidationSplitModel", self.uid, self.bestModel._to_java(), diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index d4411fdfb9..9dfcef0e40 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -99,7 +99,7 @@ class MLWriter(object): @inherit_doc class JavaMLWriter(MLWriter): """ - (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types """ def __init__(self, instance): @@ -178,7 +178,7 @@ class MLReader(object): @inherit_doc class JavaMLReader(MLReader): """ - (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types """ def __init__(self, clazz): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index bbeb6cfe6f..cd0e5b80d5 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -25,29 +25,32 @@ from pyspark.ml.util import _jvm from pyspark.mllib.common import inherit_doc, _java2py, _py2java -@inherit_doc -class JavaWrapper(Params): +class JavaWrapper(object): """ - Utility class to help create wrapper classes from Java/Scala - implementations of pipeline components. + Wrapper class for a Java companion object """ + def __init__(self, java_obj=None): + super(JavaWrapper, self).__init__() + self._java_obj = java_obj - __metaclass__ = ABCMeta - - def __init__(self): + @classmethod + def _create_from_java_class(cls, java_class, *args): """ - Initialize the wrapped java object to None + Construct this object from given Java classname and arguments """ - super(JavaWrapper, self).__init__() - #: The wrapped Java companion object. Subclasses should initialize - #: it properly. 
The param values in the Java object should be - #: synced with the Python wrapper in fit/transform/evaluate/copy. - self._java_obj = None + java_obj = JavaWrapper._new_java_obj(java_class, *args) + return cls(java_obj) + + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) @staticmethod def _new_java_obj(java_class, *args): """ - Construct a new Java object. + Returns a new Java object. """ sc = SparkContext._active_spark_context java_obj = _jvm() @@ -56,6 +59,18 @@ class JavaWrapper(Params): java_args = [_py2java(sc, arg) for arg in args] return java_obj(*java_args) + +@inherit_doc +class JavaParams(JavaWrapper, Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + #: The param values in the Java object should be + #: synced with the Python wrapper in fit/transform/evaluate/copy. + + __metaclass__ = ABCMeta + def _make_java_param_pair(self, param, value): """ Makes a Java parm pair. @@ -151,7 +166,7 @@ class JavaWrapper(Params): stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") # Generate a default new instance from the stage_name class. py_type = __get_class(stage_name) - if issubclass(py_type, JavaWrapper): + if issubclass(py_type, JavaParams): # Load information from java_stage to the instance. py_stage = py_type() py_stage._java_obj = java_stage @@ -166,7 +181,7 @@ class JavaWrapper(Params): @inherit_doc -class JavaEstimator(Estimator, JavaWrapper): +class JavaEstimator(JavaParams, Estimator): """ Base class for :py:class:`Estimator`s that wrap Java/Scala implementations. @@ -199,7 +214,7 @@ class JavaEstimator(Estimator, JavaWrapper): @inherit_doc -class JavaTransformer(Transformer, JavaWrapper): +class JavaTransformer(JavaParams, Transformer): """ Base class for :py:class:`Transformer`s that wrap Java/Scala implementations. Subclasses should ensure they have the transformer Java object @@ -213,30 +228,8 @@ class JavaTransformer(Transformer, JavaWrapper): return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) -class JavaCallable(object): - """ - Wrapper for a plain object in JVM to make Java calls, can be used - as a mixin to another class that defines a _java_obj wrapper - """ - def __init__(self, java_obj=None, sc=None): - super(JavaCallable, self).__init__() - self._sc = sc if sc is not None else SparkContext._active_spark_context - # if this class is a mixin and _java_obj is already defined then don't initialize - if java_obj is not None or not hasattr(self, "_java_obj"): - self._java_obj = java_obj - - def __del__(self): - if self._java_obj is not None: - self._sc._gateway.detach(self._java_obj) - - def _call_java(self, name, *args): - m = getattr(self._java_obj, name) - java_args = [_py2java(self._sc, arg) for arg in args] - return _java2py(self._sc, m(*java_args)) - - @inherit_doc -class JavaModel(Model, JavaCallable, JavaTransformer): +class JavaModel(JavaTransformer, Model): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. Subclasses should inherit this class before @@ -259,9 +252,8 @@ class JavaModel(Model, JavaCallable, JavaTransformer): these wrappers depend on pyspark.ml.util (both directly and via other ML classes). 
""" - super(JavaModel, self).__init__() + super(JavaModel, self).__init__(java_model) if java_model is not None: - self._java_obj = java_model self.uid = java_model.uid() def copy(self, extra=None): -- cgit v1.2.3 From 62b7f306fbf77de7f6cbb36181ebebdb4a55acc5 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 13 Apr 2016 17:17:19 -0700 Subject: [SPARK-14607] [SPARK-14484] [SQL] fix case-insensitive predicates in FileSourceStrategy ## What changes were proposed in this pull request? When prune the partitions or push down predicates, case-sensitivity is not respected. In order to make it work with case-insensitive, this PR update the AttributeReference inside predicate to use the name from schema. ## How was this patch tested? Add regression tests for case-insensitive. Author: Davies Liu Closes #12371 from davies/case_insensi. --- .../execution/datasources/FileSourceStrategy.scala | 14 +++++++++-- .../org/apache/spark/sql/sources/interfaces.scala | 5 +--- .../datasources/FileSourceStrategySuite.scala | 28 ++++++++++++++++++++++ 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index bcddf72851..80a9156ddc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -64,18 +64,28 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { // - filters that need to be evaluated again after the scan val filterSet = ExpressionSet(filters) + // The attribute name of predicate could be different than the one in schema in case of + // case insensitive, we should change them to match the one in schema, so we donot need to + // worry about case sensitivity anymore. + val normalizedFilters = filters.map { e => + e transform { + case a: AttributeReference => + a.withName(l.output.find(_.semanticEquals(a)).get.name) + } + } + val partitionColumns = l.resolve(files.partitionSchema, files.sqlContext.sessionState.analyzer.resolver) val partitionSet = AttributeSet(partitionColumns) val partitionKeyFilters = - ExpressionSet(filters.filter(_.references.subsetOf(partitionSet))) + ExpressionSet(normalizedFilters.filter(_.references.subsetOf(partitionSet))) logInfo(s"Pruning directories with: ${partitionKeyFilters.mkString(",")}") val dataColumns = l.resolve(files.dataSchema, files.sqlContext.sessionState.analyzer.resolver) // Partition keys are not available in the statistics of the files. - val dataFilters = filters.filter(_.references.intersect(partitionSet).isEmpty) + val dataFilters = normalizedFilters.filter(_.references.intersect(partitionSet).isEmpty) // Predicates with both partition keys and attributes need to be evaluated after the scan. 
val afterScanFilters = filterSet -- partitionKeyFilters diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index bea243a3be..4b9bf8daae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -593,10 +593,7 @@ class HDFSFileCatalog( } if (partitionPruningPredicates.nonEmpty) { - val predicate = - partitionPruningPredicates - .reduceOption(expressions.And) - .getOrElse(Literal(true)) + val predicate = partitionPruningPredicates.reduce(expressions.And) val boundPredicate = InterpretedPredicate.create(predicate.transform { case a: AttributeReference => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 90d7f53884..0b74f07540 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -196,6 +196,34 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) } + test("partitioned table - case insensitive") { + withSQLConf("spark.sql.caseSensitive" -> "false") { + val table = + createTable( + files = Seq( + "p1=1/file1" -> 10, + "p1=2/file2" -> 10)) + + // Only one file should be read. + checkScan(table.where("P1 = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when files in partition 1") + } + // We don't need to reevaluate filters that are only on partitions. + checkDataFilters(Set.empty) + + // Only one file should be read. + checkScan(table.where("P1 = 1 AND C1 = 1 AND (P1 + C1) = 1")) { partitions => + assert(partitions.size == 1, "when checking partitions") + assert(partitions.head.files.size == 1, "when checking files in partition 1") + assert(partitions.head.files.head.partitionValues.getInt(0) == 1, + "when checking partition values") + } + // Only the filters that do not contain the partition column should be pushed down + checkDataFilters(Set(IsNotNull("c1"), EqualTo("c1", 1))) + } + } + test("partitioned table - after scan filters") { val table = createTable( -- cgit v1.2.3 From b4819404a65f9b97c1f8deb1fcb8419969831574 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 14 Apr 2016 15:43:44 +0800 Subject: [SPARK-14596][SQL] Remove not used SqlNewHadoopRDD and some more unused imports ## What changes were proposed in this pull request? Old `HadoopFsRelation` API includes `buildInternalScan()` which uses `SqlNewHadoopRDD` in `ParquetRelation`. Because now the old API is removed, `SqlNewHadoopRDD` is not used anymore. So, this PR removes `SqlNewHadoopRDD` and several unused imports. This was discussed in https://github.com/apache/spark/pull/12326. ## How was this patch tested? Several related existing unit tests and `sbt scalastyle`. Author: hyukjinkwon Closes #12354 from HyukjinKwon/SPARK-14596. 
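For context, the thread-local renamed here to `InputFileNameHolder` is what backs the SQL `input_file_name()` function, so user-facing behaviour is unchanged. A hedged PySpark sketch, assuming an existing `sqlContext` and a hypothetical Parquet directory `/tmp/events`:

```python
# Illustrative only: reads a (hypothetical) Parquet directory and tags each row with
# the file it came from; input_file_name() reads the value set in InputFileNameHolder.
from pyspark.sql.functions import input_file_name

df = sqlContext.read.parquet("/tmp/events")  # assumes `sqlContext` is already in scope
df.select("*", input_file_name().alias("source_file")).show(truncate=False)
```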
--- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 8 +- .../org/apache/spark/rdd/InputFileNameHolder.scala | 41 +++ .../apache/spark/rdd/SqlNewHadoopRDDState.scala | 41 --- project/MimaExcludes.scala | 5 - .../sql/catalyst/expressions/InputFileName.scala | 8 +- .../sql/execution/datasources/FileScanRDD.scala | 11 +- .../execution/datasources/SqlNewHadoopRDD.scala | 282 --------------------- .../org/apache/spark/sql/internal/SQLConf.scala | 1 - .../datasources/FileSourceStrategySuite.scala | 5 +- 9 files changed, 54 insertions(+), 348 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala delete mode 100644 core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 08db96edd6..ac5ba9e79f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -213,15 +213,13 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD - val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.inputSplit.value match { - case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDDState.unsetInputFileName() + case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) + case _ => InputFileNameHolder.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -271,7 +269,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala new file mode 100644 index 0000000000..108e9d2558 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This holds file names of the current Spark task. This is used in HadoopRDD, + * FileScanRDD and InputFileName function in Spark SQL. + */ +private[spark] object InputFileNameHolder { + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala deleted file mode 100644 index 3f15fff793..0000000000 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.unsafe.types.UTF8String - -/** - * State for SqlNewHadoopRDD objects. This is split this way because of the package splits. - * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD - */ -private[spark] object SqlNewHadoopRDDState { - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. 
- */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() - -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a30581eb48..313bf93b5d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -847,7 +847,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), @@ -856,10 +855,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), @@ -870,7 +867,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), @@ -884,7 +880,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index dbd0acf06c..2ed6fc0d38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDDState +import org.apache.spark.rdd.InputFileNameHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String /** - * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] + * Expression that returns the name of the current file being read. */ @ExpressionDescription( usage = "_FUNC_() - Returns the name of the current file being read if available", @@ -40,12 +40,12 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDDState.getInputFileName() + InputFileNameHolder.getInputFileName() } override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" + "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 988c785dbe..468e101fed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Partition, TaskContext} -import org.apache.spark.rdd.{RDD, SqlNewHadoopRDDState} +import org.apache.spark.rdd.{InputFileNameHolder, RDD} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -37,7 +37,6 @@ case class PartitionedFile( } } - /** * A collection of files that should be read as a single task possibly from multiple partitioned * directories. 
@@ -50,7 +49,7 @@ class FileScanRDD( @transient val sqlContext: SQLContext, readFunction: (PartitionedFile) => Iterator[InternalRow], @transient val filePartitions: Seq[FilePartition]) - extends RDD[InternalRow](sqlContext.sparkContext, Nil) { + extends RDD[InternalRow](sqlContext.sparkContext, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { @@ -65,17 +64,17 @@ class FileScanRDD( if (files.hasNext) { val nextFile = files.next() logInfo(s"Reading File $nextFile") - SqlNewHadoopRDDState.setInputFileName(nextFile.filePath) + InputFileNameHolder.setInputFileName(nextFile.filePath) currentIterator = readFunction(nextFile) hasNext } else { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() false } } override def close() = { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala deleted file mode 100644 index 4d6864d8ba..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ /dev/null @@ -1,282 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import java.text.SimpleDateFormat -import java.util.Date - -import scala.reflect.ClassTag - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} - -import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.DataReadMethod -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} - -private[spark] class SqlNewHadoopPartition( - rddId: Int, - val index: Int, - rawSplit: InputSplit with Writable) - extends SparkPartition { - - val serializableHadoopSplit = new SerializableWritable(rawSplit) - - override def hashCode(): Int = 41 * (41 + rddId) + index -} - -/** - * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, - * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). - * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. - * 1. 
A shared broadcast Hadoop Configuration. - * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side - * to the shared Hadoop Configuration. - * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side - * and the executor side to the shared Hadoop Configuration. - * - * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. - */ -private[spark] class SqlNewHadoopRDD[V: ClassTag]( - sqlContext: SQLContext, - broadcastedConf: Broadcast[SerializableConfiguration], - @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], - initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[Void, V]], - valueClass: Class[V]) - extends RDD[V](sqlContext.sparkContext, Nil) with Logging { - - protected def getJob(): Job = { - val conf = broadcastedConf.value.value - // "new Job" will make a copy of the conf. Then, it is - // safe to mutate conf properties with initLocalJobFuncOpt - // and initDriverSideJobFuncOpt. - val newJob = Job.getInstance(conf) - initLocalJobFuncOpt.map(f => f(newJob)) - newJob - } - - def getConf(isDriverSide: Boolean): Configuration = { - val job = getJob() - if (isDriverSide) { - initDriverSideJobFuncOpt.map(f => f(job)) - } - job.getConfiguration - } - - private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - formatter.format(new Date()) - } - - @transient protected val jobId = new JobID(jobTrackerId, id) - - override def getPartitions: Array[SparkPartition] = { - val conf = getConf(isDriverSide = true) - val inputFormat = inputFormatClass.newInstance - inputFormat match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val jobContext = new JobContextImpl(conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } - - override def compute( - theSplit: SparkPartition, - context: TaskContext): Iterator[V] = { - val iter = new Iterator[V] { - val split = theSplit.asInstanceOf[SqlNewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf(isDriverSide = false) - - val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) - val existingBytesRead = inputMetrics.bytesRead - - // Sets the thread local variable for the file's name - split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDDState.unsetInputFileName() - } - - // Find a function that will return the FileSystem bytes read by this thread. Do this before - // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } - - // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. - // If we do a coalesce, however, we are likely to compute multiple partitions in the same - // task and in the same thread, in which case we need to avoid override values written by - // previous partitions (SPARK-13071). 
- def updateBytesRead(): Unit = { - getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) - } - } - - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private[this] var reader: RecordReader[Void, V] = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => close()) - - private[this] var havePair = false - private[this] var finished = false - - override def hasNext: Boolean = { - if (context.isInterrupted()) { - throw new TaskKilledException - } - if (!finished && !havePair) { - finished = !reader.nextKeyValue - if (finished) { - // Close and release the reader here; close() will also be called when the task - // completes, but for tasks that read from many files, it helps to release the - // resources early. - close() - } - havePair = !finished - } - !finished - } - - override def next(): V = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - havePair = false - if (!finished) { - inputMetrics.incRecordsReadInternal(1) - } - if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { - updateBytesRead() - } - reader.getCurrentValue - } - - private def close() { - if (reader != null) { - SqlNewHadoopRDDState.unsetInputFileName() - // Close the reader and release it. Note: it's very important that we don't close the - // reader more than once, since that exposes us to MAPREDUCE-5918 when running against - // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic - // corruption issues when reading compressed input. - try { - reader.close() - } catch { - case e: Exception => - if (!ShutdownHookManager.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) - } - } finally { - reader = null - } - if (getBytesReadCallback.isDefined) { - updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. 
- try { - inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) - } - } - } - } - } - iter - } - - override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { - val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value - val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => - try { - val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] - Some(HadoopRDD.convertSplitLocationInfo(infos)) - } catch { - case e : Exception => - logDebug("Failed to use InputSplit#getLocationInfo.", e) - None - } - case None => None - } - locs.getOrElse(split.getLocations.filter(_ != "localhost")) - } - - override def persist(storageLevel: StorageLevel): this.type = { - if (storageLevel.deserialized) { - logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + - " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + - " Use a map transformation to make copies of the records.") - } - super.persist(storageLevel) - } - - /** - * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to - * the given function rather than the index of the partition. - */ - private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( - prev: RDD[T], - f: (InputSplit, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false) - extends RDD[U](prev) { - - override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - - override def getPartitions: Array[SparkPartition] = firstParent[T].partitions - - override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { - val partition = split.asInstanceOf[SqlNewHadoopPartition] - val inputSplit = partition.serializableHadoopSplit.value - f(inputSplit, firstParent[T].iterator(split, context)) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e74fb00cb2..2f9d63c2e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 0b74f07540..dac56d3936 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -22,8 +22,6 @@ import java.io.File import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.mapreduce.Job -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} @@ -34,8 +32,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.BitSet +import org.apache.spark.util.Utils class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ -- cgit v1.2.3 From 478af2f45595913c9b8f560d13e8d88447486f99 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Thu, 14 Apr 2016 09:42:15 +0100 Subject: [SPARK-14573][PYSPARK][BUILD] Fix PyDoc Makefile & highlighting issues ## What changes were proposed in this pull request? The PyDoc Makefile used "=" rather than "?=" for setting env variables so it overwrote the user values. This ignored the environment variables we set for linting allowing warnings through. This PR also fixes the warnings that had been introduced. ## How was this patch tested? manual local export & make Author: Holden Karau Closes #12336 from holdenk/SPARK-14573-fix-pydoc-makefile. --- python/docs/Makefile | 8 ++++---- python/pyspark/ml/regression.py | 2 +- python/pyspark/sql/context.py | 2 +- python/pyspark/sql/dataframe.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/docs/Makefile b/python/docs/Makefile index 903009790b..905e0215c2 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -2,10 +2,10 @@ # # You can set these variables from the command line. 
-SPHINXOPTS = -SPHINXBUILD = sphinx-build -PAPER = -BUILDDIR = _build +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +PAPER ?= +BUILDDIR ?= _build export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.9.2-src.zip) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 316d7e30bc..c064fe500c 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -28,7 +28,7 @@ from pyspark.sql import DataFrame __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', 'DecisionTreeRegressor', 'DecisionTreeRegressionModel', 'GBTRegressor', 'GBTRegressionModel', - 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel' + 'GeneralizedLinearRegression', 'GeneralizedLinearRegressionModel', 'IsotonicRegression', 'IsotonicRegressionModel', 'LinearRegression', 'LinearRegressionModel', 'LinearRegressionSummary', 'LinearRegressionTrainingSummary', diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4008332c84..11dfcfe13e 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -405,7 +405,7 @@ class SQLContext(object): >>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... - Py4JJavaError:... + Py4JJavaError: ... """ if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d473d6b534..b4fa836893 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -60,7 +60,7 @@ class DataFrame(object): people = sqlContext.read.parquet("...") department = sqlContext.read.parquet("...") - people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + people.filter(people.age > 30).join(department, people.deptId == department.id)\ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) .. note:: Experimental -- cgit v1.2.3 From 6fc3dc8839eaed673c64ec87af6dfe24f8cebe0c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 14 Apr 2016 09:43:41 +0100 Subject: [MINOR][SQL] Remove extra anonymous closure within functional transformations ## What changes were proposed in this pull request? This PR removes extra anonymous closure within functional transformations. For example, ```scala .map(item => { ... }) ``` which can be just simply as below: ```scala .map { item => ... } ``` ## How was this patch tested? Related unit tests and `sbt scalastyle`. Author: hyukjinkwon Closes #12382 from HyukjinKwon/minor-extra-closers. 
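A side note on the one-character `__all__` fix in python/pyspark/ml/regression.py above: the missing comma is not cosmetic, because Python implicitly concatenates adjacent string literals and silently fuses the two export names into one. A small standalone illustration:

```python
# The bug pattern fixed above: without the comma, implicit string-literal
# concatenation turns two intended export names into a single bogus entry.
broken_all = [
    'GeneralizedLinearRegressionModel'   # <- missing comma
    'IsotonicRegression',
]
fixed_all = [
    'GeneralizedLinearRegressionModel',
    'IsotonicRegression',
]
print(broken_all)  # ['GeneralizedLinearRegressionModelIsotonicRegression']
print(fixed_all)   # ['GeneralizedLinearRegressionModel', 'IsotonicRegression']
```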
--- .../main/scala/org/apache/spark/SparkContext.scala | 4 ++-- .../org/apache/spark/deploy/master/Master.scala | 4 ++-- .../main/scala/org/apache/spark/rdd/BlockRDD.scala | 4 ++-- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 4 ++-- .../apache/spark/rdd/ParallelCollectionRDD.scala | 4 ++-- .../spark/rdd/PartitionerAwareUnionRDD.scala | 4 ++-- .../cluster/mesos/MesosSchedulerUtils.scala | 4 ++-- .../spark/shuffle/BlockStoreShuffleReader.scala | 4 ++-- .../org/apache/spark/ui/jobs/ExecutorTable.scala | 4 ++-- .../streaming/RecoverableNetworkWordCount.scala | 4 ++-- .../examples/streaming/SqlNetworkWordCount.scala | 4 ++-- .../flume/sink/SparkAvroCallbackHandler.scala | 4 ++-- .../spark/streaming/flume/sink/SparkSink.scala | 12 +++++----- .../flume/sink/TransactionProcessor.scala | 12 +++++----- .../streaming/flume/FlumePollingInputDStream.scala | 4 ++-- .../streaming/flume/PollingFlumeTestUtils.scala | 4 ++-- .../expressions/codegen/CodeGenerator.scala | 4 ++-- .../spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++-- .../spark/sql/execution/basicOperators.scala | 4 ++-- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 4 ++-- .../org/apache/spark/sql/hive/HiveInspectors.scala | 27 +++++++++------------- .../org/apache/spark/streaming/Checkpoint.scala | 8 +++---- .../spark/streaming/dstream/StateDStream.scala | 4 ++-- .../streaming/scheduler/ReceiverTracker.scala | 4 ++-- 24 files changed, 67 insertions(+), 72 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 966198dd5e..e41088f7c8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -723,7 +723,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (safeEnd - safeStart) / step + 1 } } - parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => { + parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex { (i, _) => val partitionStart = (i * numElements) / numSlices * step + start val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start def getSafeMargin(bi: BigInt): Long = @@ -762,7 +762,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli ret } } - }) + } } /** Distribute a local Scala collection to form an RDD. 
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9bd3fc1033..b443e8f051 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -843,10 +843,10 @@ private[deploy] class Master( addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) - completedApps.take(toRemove).foreach( a => { + completedApps.take(toRemove).foreach { a => Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) - }) + } completedApps.trimStart(toRemove) } completedApps += app // Remember it in our history diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 8358244987..63d1d1767a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -35,9 +35,9 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.length).map(i => { + (0 until blockIds.length).map { i => new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] - }).toArray + }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index ac5ba9e79f..f7c646c668 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -422,7 +422,7 @@ private[spark] object HadoopRDD extends Logging { private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = { val out = ListBuffer[String]() - infos.foreach { loc => { + infos.foreach { loc => val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get. 
getLocation.invoke(loc).asInstanceOf[String] if (locationStr != "localhost") { @@ -434,7 +434,7 @@ private[spark] object HadoopRDD extends Logging { out += new HostTaskLocation(locationStr).toString } } - }} + } out.seq } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 462fb39ea2..bb84e4af15 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -121,11 +121,11 @@ private object ParallelCollectionRDD { // Sequences need to be sliced at the same set of index positions for operations // like RDD.zip() to behave as expected def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = { - (0 until numSlices).iterator.map(i => { + (0 until numSlices).iterator.map { i => val start = ((i * length) / numSlices).toInt val end = (((i + 1) * length) / numSlices).toInt (start, end) - }) + } } seq match { case r: Range => diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index c3579d761d..0abba15bec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -68,9 +68,9 @@ class PartitionerAwareUnionRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { val numPartitions = partitioner.get.numPartitions - (0 until numPartitions).map(index => { + (0 until numPartitions).map { index => new PartitionerAwareUnionRDDPartition(rdds, index) - }).toArray + }.toArray } // Get the location where most of the partitions of parent RDDs are located diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 7295d50682..1e322ac679 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -226,7 +226,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return */ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { - offerAttributes.asScala.map(attr => { + offerAttributes.asScala.map { attr => val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar case Value.Type.RANGES => attr.getRanges @@ -234,7 +234,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { case Value.Type.TEXT => attr.getText } (attr.getName, attrValue) - }).toMap + }.toMap } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 637b2dfc19..876cdfaa87 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -69,10 +69,10 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Update the context task metrics for each record read. 
val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map(record => { + recordIter.map { record => readMetrics.incRecordsRead(1) record - }), + }, context.taskMetrics().mergeShuffleReadMetrics()) // An interruptible iterator must be used here in order to support task cancellation diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 1304efd8f2..f609fb4cd2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -42,13 +42,13 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage var hasShuffleWrite = false var hasShuffleRead = false var hasBytesSpilled = false - stageData.foreach(data => { + stageData.foreach { data => hasInput = data.hasInput hasOutput = data.hasOutput hasShuffleRead = data.hasShuffleRead hasShuffleWrite = data.hasShuffleWrite hasBytesSpilled = data.hasBytesSpilled - }) + }
    diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index b6b8bc33f7..bb2af9cd72 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -116,7 +116,7 @@ object RecoverableNetworkWordCount { val lines = ssc.socketTextStream(ip, port) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { + wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => // Get or register the blacklist Broadcast val blacklist = WordBlacklist.getInstance(rdd.sparkContext) // Get or register the droppedWordsCounter Accumulator @@ -135,7 +135,7 @@ object RecoverableNetworkWordCount { println("Dropped " + droppedWordsCounter.value + " word(s) totally") println("Appending to " + outputFile.getAbsolutePath) Files.append(output + "\n", outputFile, Charset.defaultCharset()) - }) + } ssc } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 3727f8fe6a..918e124065 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -59,7 +59,7 @@ object SqlNetworkWordCount { val words = lines.flatMap(_.split(" ")) // Convert RDDs of the words DStream to DataFrame and run SQL query - words.foreachRDD((rdd: RDD[String], time: Time) => { + words.foreachRDD { (rdd: RDD[String], time: Time) => // Get the singleton instance of SQLContext val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) import sqlContext.implicits._ @@ -75,7 +75,7 @@ object SqlNetworkWordCount { sqlContext.sql("select word, count(*) as total from words group by word") println(s"========= $time =========") wordCountsDataFrame.show() - }) + } ssc.start() ssc.awaitTermination() diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 719fca0938..8050ec357e 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -129,9 +129,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha * @param success Whether the batch was successful or not. 
*/ private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { - removeAndGetProcessor(sequenceNumber).foreach(processor => { + removeAndGetProcessor(sequenceNumber).foreach { processor => processor.batchProcessed(success) - }) + } } /** diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 14dffb15fe..41f27e9376 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -88,23 +88,23 @@ class SparkSink extends AbstractSink with Logging with Configurable { // dependencies which are being excluded in the build. In practice, // Netty dependencies are already available on the JVM as Flume would have pulled them in. serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) - serverOpt.foreach(server => { + serverOpt.foreach { server => logInfo("Starting Avro server for sink: " + getName) server.start() - }) + } super.start() } override def stop() { logInfo("Stopping Spark Sink: " + getName) - handler.foreach(callbackHandler => { + handler.foreach { callbackHandler => callbackHandler.shutdown() - }) - serverOpt.foreach(server => { + } + serverOpt.foreach { server => logInfo("Stopping Avro Server for sink: " + getName) server.close() server.join() - }) + } blockingLatch.countDown() super.stop() } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index b15c2097e5..19e736f016 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -110,7 +110,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg("Something went wrong. Channel was " + "unable to create a transaction!") } - txOpt.foreach(tx => { + txOpt.foreach { tx => tx.begin() val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) val loop = new Breaks @@ -145,7 +145,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // At this point, the events are available, so fill them into the event batch eventBatch = new EventBatch("", seqNum, events) } - }) + } } catch { case interrupted: InterruptedException => // Don't pollute logs if the InterruptedException came from this being stopped @@ -156,9 +156,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, logWarning("Error while processing transaction.", e) eventBatch.setErrorMsg(e.getMessage) try { - txOpt.foreach(tx => { + txOpt.foreach { tx => rollbackAndClose(tx, close = true) - }) + } } finally { txOpt = None } @@ -174,7 +174,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, */ private def processAckOrNack() { batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) - txOpt.foreach(tx => { + txOpt.foreach { tx => if (batchSuccess) { try { logDebug("Committing transaction") @@ -197,7 +197,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // cause issues. 
This is required to ensure the TransactionProcessor instance is not leaked parent.removeAndGetProcessor(seqNum) } - }) + } } /** diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 250bfc1718..54565840fa 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -79,11 +79,11 @@ private[streaming] class FlumePollingReceiver( override def onStart(): Unit = { // Create the connections to each Flume agent. - addresses.foreach(host => { + addresses.foreach { host => val transceiver = new NettyTransceiver(host, channelFactory) val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) connections.add(new FlumeConnection(transceiver, client)) - }) + } for (i <- 0 until parallelism) { logInfo("Starting Flume Polling Receiver worker threads..") // Threads that pull data from Flume. diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index 1a96df6e94..6a4dafb8ed 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -123,9 +123,9 @@ private[flume] class PollingFlumeTestUtils { val latch = new CountDownLatch(batchCount * channels.size) sinks.foreach(_.countdownWhenBatchReceived(latch)) - channels.foreach(channel => { + channels.foreach { channel => executorCompletion.submit(new TxnSubmitter(channel)) - }) + } for (i <- 0 until channels.size) { executorCompletion.take() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ee7f4fadca..f43626ca81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -519,7 +519,7 @@ class CodegenContext { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. 
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) - commonExprs.foreach(e => { + commonExprs.foreach { e => val expr = e.head val fnName = freshName("evalExpr") val isNull = s"${fnName}IsNull" @@ -561,7 +561,7 @@ class CodegenContext { subexprFunctions += s"$fnName($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) - }) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 438cbabdbb..aeb1842677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -286,10 +286,10 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { assert(children.nonEmpty) if (projectList.forall(_.deterministic)) { val newFirstChild = Project(projectList, children.head) - val newOtherChildren = children.tail.map ( child => { + val newOtherChildren = children.tail.map { child => val rewrites = buildRewrites(children.head, child) Project(projectList.map(pushToRight(_, rewrites)), child) - } ) + } Union(newFirstChild +: newOtherChildren) } else { p diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index aba500ad8d..344aaff348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -400,7 +400,7 @@ case class Range( sqlContext .sparkContext .parallelize(0 until numSlices, numSlices) - .mapPartitionsWithIndex((i, _) => { + .mapPartitionsWithIndex { (i, _) => val partitionStart = (i * numElements) / numSlices * step + start val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start def getSafeMargin(bi: BigInt): Long = @@ -444,7 +444,7 @@ case class Range( unsafeRow } } - }) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b7ff5f7242..065c8572b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -251,12 +251,12 @@ object JdbcUtils extends Logging { def schemaString(df: DataFrame, url: String): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { + df.schema.fields foreach { field => val name = field.name val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") - }} + } if (sb.length < 2) "" else sb.substring(2) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 589862c7c0..585befe378 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -450,9 +450,7 @@ private[hive] trait HiveInspectors { if (o != null) { val array = o.asInstanceOf[ArrayData] val values = new 
java.util.ArrayList[Any](array.numElements()) - array.foreach(elementType, (_, e) => { - values.add(wrapper(e)) - }) + array.foreach(elementType, (_, e) => values.add(wrapper(e))) values } else { null @@ -468,9 +466,8 @@ private[hive] trait HiveInspectors { if (o != null) { val map = o.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(mt.keyType, mt.valueType, (k, v) => { - jmap.put(keyWrapper(k), valueWrapper(v)) - }) + map.foreach(mt.keyType, mt.valueType, (k, v) => + jmap.put(keyWrapper(k), valueWrapper(v))) jmap } else { null @@ -587,9 +584,9 @@ private[hive] trait HiveInspectors { case x: ListObjectInspector => val list = new java.util.ArrayList[Object] val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => { + a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => list.add(wrap(e, x.getListElementObjectInspector, tpe)) - }) + ) list case x: MapObjectInspector => val keyType = dataType.asInstanceOf[MapType].keyType @@ -599,10 +596,10 @@ private[hive] trait HiveInspectors { // Some UDFs seem to assume we pass in a HashMap. val hashMap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(keyType, valueType, (k, v) => { + map.foreach(keyType, valueType, (k, v) => hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType), wrap(v, x.getMapValueObjectInspector, valueType)) - }) + ) hashMap } @@ -704,9 +701,8 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[ArrayData].foreach(dt, (_, e) => { - list.add(wrap(e, listObjectInspector, dt)) - }) + value.asInstanceOf[ArrayData].foreach(dt, (_, e) => + list.add(wrap(e, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -718,9 +714,8 @@ private[hive] trait HiveInspectors { val map = value.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(keyType, valueType, (k, v) => { - jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType)) - }) + map.foreach(keyType, valueType, (k, v) => + jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType))) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, jmap) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 5cc677d085..0395600954 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -247,10 +247,10 @@ class CheckpointWriter( // Delete old checkpoint files val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) if (allCheckpointFiles.size > 10) { - allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { + allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach { file => logInfo("Deleting " + file) fs.delete(file, true) - }) + } } // All done, print success @@ -345,7 +345,7 @@ object CheckpointReader extends Logging { // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) var readError: Exception = null - checkpointFiles.foreach(file => { + checkpointFiles.foreach { file => logInfo("Attempting to load checkpoint from file " + file) try { 
val fis = fs.open(file) @@ -358,7 +358,7 @@ object CheckpointReader extends Logging { readError = e logWarning("Error reading checkpoint from file " + file, e) } - }) + } // If none of checkpoint files could be read, then throw exception if (!ignoreReadError) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 28aed0ca45..8efb09a8ce 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -48,11 +48,11 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // and then apply the update function val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { - val i = iterator.map(t => { + val i = iterator.map { t => val itr = t._2._2.iterator val headOption = if (itr.hasNext) Some(itr.next()) else None (t._1, t._2._1.toSeq, headOption) - }) + } updateFuncLocal(i) } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 3b33a979df..9aa2f0bbb9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -434,11 +434,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false * worker nodes as a parallel collection, and runs them. */ private def launchReceivers(): Unit = { - val receivers = receiverInputStreams.map(nis => { + val receivers = receiverInputStreams.map { nis => val rcvr = nis.getReceiver() rcvr.setReceiverId(nis.id) rcvr - }) + } runDummySparkJob() -- cgit v1.2.3 From 3cf3db17b35c98a408014e1810cb797d8415ffd3 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 14 Apr 2016 08:08:09 -0700 Subject: [SPARK-14518][SQL] Support Comment in CREATE VIEW #### What changes were proposed in this pull request? **HQL Syntax**: [Create View](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL-Create/Drop/AlterView ) ```SQL CREATE VIEW [IF NOT EXISTS] [db_name.]view_name [(column_name [COMMENT column_comment], ...) ] [COMMENT view_comment] [TBLPROPERTIES (property_name = property_value, ...)] AS SELECT ...; ``` Add a support for the `[COMMENT view_comment]` clause #### How was this patch tested? Modified the existing test cases to verify the correctness. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #12288 from gatorsmile/addCommentInCreateView. 
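As a rough usage sketch (names invented; a `HiveContext`-backed `sqlContext` is assumed, since this view DDL goes through the Hive parser), the clause accepted after this change can be exercised the same way the new test below does:

```scala
// Hypothetical table and view names, for illustration only.
sqlContext.sql("CREATE TABLE people(name STRING) COMMENT 'demo table'")
sqlContext.sql(
  "CREATE VIEW adult_names COMMENT 'view with a comment' AS SELECT name FROM people")
```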
--- .../apache/spark/sql/hive/execution/HiveSqlParser.scala | 16 ++++++---------- .../org/apache/spark/sql/hive/HiveDDLCommandSuite.scala | 11 ++--------- .../apache/spark/sql/hive/execution/HiveDDLSuite.scala | 16 ++++++++++++++++ 3 files changed, 24 insertions(+), 19 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index b14db7fe71..8c707079a1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -26,18 +26,13 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlAstBuilder import org.apache.spark.sql.execution.command.CreateTable -import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView} -import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveMetastoreTypes, HiveSerDe} -import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper +import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView, HiveSerDe} /** * Concrete parser for HiveQl statements. @@ -252,9 +247,6 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { if (ctx.identifierList != null) { throw new ParseException(s"Operation not allowed: partitioned views", ctx) } else { - if (ctx.STRING != null) { - throw new ParseException("Unsupported operation: COMMENT clause", ctx) - } val identifiers = Option(ctx.identifierCommentList).toSeq.flatMap(_.identifierComment.asScala) val schema = identifiers.map { ic => CatalogColumn(ic.identifier.getText, null, nullable = true, Option(ic.STRING).map(string)) @@ -262,6 +254,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { createView( ctx, ctx.tableIdentifier, + comment = Option(ctx.STRING).map(string), schema, ctx.query, Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), @@ -278,6 +271,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { createView( ctx, ctx.tableIdentifier, + comment = None, Seq.empty, ctx.query, Map.empty, @@ -291,6 +285,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { private def createView( ctx: ParserRuleContext, name: TableIdentifierContext, + comment: Option[String], schema: Seq[CatalogColumn], query: QueryContext, properties: Map[String, String], @@ -304,7 +299,8 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { storage = EmptyStorageFormat, properties = properties, viewOriginalText = sql, - viewText = sql) + viewText = sql, + comment = comment) CreateView(tableDesc, plan(query), allowExist, replace, command(ctx)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 68d3ea6ed9..b87f035cd1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -236,15 +236,6 @@ class HiveDDLCommandSuite extends PlanTest { |FROM testData """.stripMargin) } - intercept[ParseException] { - parser.parsePlan( - """ - |CREATE OR REPLACE VIEW IF NOT EXISTS view1 (col1, col3) - |COMMENT 'blabla' - |TBLPROPERTIES('prop1Key'="prop1Val") - |AS SELECT * FROM tab1 - """.stripMargin) - } } test("Invalid interval term should throw AnalysisException") { @@ -532,6 +523,7 @@ class HiveDDLCommandSuite extends PlanTest { """ |CREATE OR REPLACE VIEW IF NOT EXISTS view1 |(col1, col3) + |COMMENT 'BLABLA' |TBLPROPERTIES('prop1Key'="prop1Val") |AS SELECT * FROM tab1 """.stripMargin @@ -551,6 +543,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.outputFormat.isEmpty) assert(desc.storage.serde.isEmpty) assert(desc.properties == Map("prop1Key" -> "prop1Val")) + assert(desc.comment == Option("BLABLA")) } test("create view -- partitioned view") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 78ccdc7adb..c82c7f6ca6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -110,6 +110,22 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("create table and view with comment") { + val catalog = hiveContext.sessionState.catalog + val tabName = "tab1" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(c1 int) COMMENT 'BLABLA'") + val viewName = "view1" + withView(viewName) { + sql(s"CREATE VIEW $viewName COMMENT 'no comment' AS SELECT * FROM $tabName") + val tableMetadata = catalog.getTableMetadata(TableIdentifier(tabName, Some("default"))) + val viewMetadata = catalog.getTableMetadata(TableIdentifier(viewName, Some("default"))) + assert(tableMetadata.properties.get("comment") == Option("BLABLA")) + assert(viewMetadata.properties.get("comment") == Option("no comment")) + } + } + } + test("drop views") { withTable("tab1") { val tabName = "tab1" -- cgit v1.2.3 From f83ba454a507bec0cc389d9a382cd71add7f17c1 Mon Sep 17 00:00:00 2001 From: Dhruve Ashar Date: Thu, 14 Apr 2016 10:29:14 -0500 Subject: [SPARK-14572][DOC] Update config docs to allow -Xms in extraJavaOptions ## What changes were proposed in this pull request? The configuration docs are updated to reflect the changes introduced with [SPARK-12384](https://issues.apache.org/jira/browse/SPARK-12384). This allows the user to specify initial heap memory settings through the extraJavaOptions for executor, driver and am. ## How was this patch tested? The changes are tested in [SPARK-12384](https://issues.apache.org/jira/browse/SPARK-12384). This is just documenting the changes made. Author: Dhruve Ashar Closes #12333 from dhruve/doc/SPARK-14572. --- docs/configuration.md | 11 +++++++---- docs/running-on-yarn.md | 4 +++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 937852ffde..16d5be62f9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -225,11 +225,14 @@ Apart from these, the following properties are also available, and may be useful + your default properties file. 
@@ -269,9 +272,9 @@ Apart from these, the following properties are also available, and may be useful diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index ddc75a70b9..09701abdb0 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -342,7 +342,9 @@ If you need a reference to the proper location to put log files in the YARN so t -- cgit v1.2.3 From 0d22092cd9c8876a7f226add578ff1c025012fe9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 14 Apr 2016 08:34:11 -0700 Subject: [SPARK-14125][SQL] Native DDL Support: Alter View #### What changes were proposed in this pull request? This PR is to provide a native DDL support for the following three Alter View commands: Based on the Hive DDL document: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL ##### 1. ALTER VIEW RENAME **Syntax:** ```SQL ALTER VIEW view_name RENAME TO new_view_name ``` - to change the name of a view to a different name - not allowed to rename a view's name by ALTER TABLE ##### 2. ALTER VIEW SET TBLPROPERTIES **Syntax:** ```SQL ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); ``` - to add metadata to a view - not allowed to set views' properties by ALTER TABLE - ignore it if trying to set a view's existing property key when the value is the same - overwrite the value if trying to set a view's existing key to a different value ##### 3. ALTER VIEW UNSET TBLPROPERTIES **Syntax:** ```SQL ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key') ``` - to remove metadata from a view - not allowed to unset views' properties by ALTER TABLE - issue an exception if trying to unset a view's non-existent key #### How was this patch tested? Added test cases to verify if it works properly. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #12324 from gatorsmile/alterView. 
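A hedged sketch of the three statements as they would be issued through `sql(...)` (the view and property names are invented, and an existing view `v1` is assumed):

```scala
// Illustrative only; 'v1' must already exist as a view.
sqlContext.sql("ALTER VIEW v1 RENAME TO v2")
sqlContext.sql("ALTER VIEW v2 SET TBLPROPERTIES ('comment' = 'nightly aggregates')")
sqlContext.sql("ALTER VIEW v2 UNSET TBLPROPERTIES IF EXISTS ('comment')")
```

Issuing the same statements with `ALTER TABLE` against a view (or `ALTER VIEW` against a table) now fails with an `AnalysisException`, as the new tests below verify.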
--- .../spark/sql/execution/SparkSqlParser.scala | 9 +- .../apache/spark/sql/execution/command/ddl.scala | 29 +++++- .../spark/sql/execution/command/tables.scala | 4 +- .../sql/execution/command/DDLCommandSuite.scala | 18 ++-- .../spark/sql/hive/execution/HiveDDLSuite.scala | 112 +++++++++++++++++++++ 5 files changed, 157 insertions(+), 15 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index af92cecee5..8ed6ed21d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -393,7 +393,8 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { AlterTableRename( visitTableIdentifier(ctx.from), - visitTableIdentifier(ctx.to)) + visitTableIdentifier(ctx.to), + ctx.VIEW != null) } /** @@ -409,7 +410,8 @@ class SparkSqlAstBuilder extends AstBuilder { ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableSetProperties( visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList)) + visitTablePropertyList(ctx.tablePropertyList), + ctx.VIEW != null) } /** @@ -426,7 +428,8 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableUnsetProperties( visitTableIdentifier(ctx.tableIdentifier), visitTablePropertyList(ctx.tablePropertyList).keys.toSeq, - ctx.EXISTS != null) + ctx.EXISTS != null, + ctx.VIEW != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 234099ad15..fc37a142cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -23,7 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} -import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType, SessionCatalog} import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ @@ -235,11 +235,13 @@ case class DropTable( */ case class AlterTableSetProperties( tableName: TableIdentifier, - properties: Map[String, String]) + properties: Map[String, String], + isView: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, tableName, isView) val table = catalog.getTableMetadata(tableName) val newProperties = table.properties ++ properties if (DDLUtils.isDatasourceTable(newProperties)) { @@ -265,11 +267,13 @@ case class AlterTableSetProperties( case class AlterTableUnsetProperties( tableName: TableIdentifier, propKeys: Seq[String], - ifExists: Boolean) + ifExists: Boolean, + isView: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, tableName, isView) val table = 
catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( @@ -513,5 +517,24 @@ private object DDLUtils { def isDatasourceTable(table: CatalogTable): Boolean = { isDatasourceTable(table.properties) } + + /** + * If the command ALTER VIEW is to alter a table or ALTER TABLE is to alter a view, + * issue an exception [[AnalysisException]]. + */ + def verifyAlterTableType( + catalog: SessionCatalog, + tableIdentifier: TableIdentifier, + isView: Boolean): Unit = { + catalog.getTableMetadataOption(tableIdentifier).map(_.tableType match { + case CatalogTableType.VIRTUAL_VIEW if !isView => + throw new AnalysisException( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead") + case o if o != CatalogTableType.VIRTUAL_VIEW && isView => + throw new AnalysisException( + s"Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead") + case _ => + }) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 9c6030502d..e315598daa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -67,11 +67,13 @@ case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends Runnab */ case class AlterTableRename( oldName: TableIdentifier, - newName: TableIdentifier) + newName: TableIdentifier, + isView: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog + DDLUtils.verifyAlterTableType(catalog, oldName, isView) catalog.invalidateTable(oldName) catalog.renameTable(oldName, newName) Seq.empty[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 6e6475ee29..d6ccaf9348 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -214,10 +214,12 @@ class DDLCommandSuite extends PlanTest { val parsed_view = parser.parsePlan(sql_view) val expected_table = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None)) + TableIdentifier("new_table_name", None), + isView = false) val expected_view = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None)) + TableIdentifier("new_table_name", None), + isView = true) comparePlans(parsed_table, expected_table) comparePlans(parsed_view, expected_view) } @@ -244,14 +246,14 @@ class DDLCommandSuite extends PlanTest { val tableIdent = TableIdentifier("table_name", None) val expected1_table = AlterTableSetProperties( - tableIdent, Map("test" -> "test", "comment" -> "new_comment")) + tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false) val expected2_table = AlterTableUnsetProperties( - tableIdent, Seq("comment", "test"), ifExists = false) + tableIdent, Seq("comment", "test"), ifExists = false, isView = false) val expected3_table = AlterTableUnsetProperties( - tableIdent, Seq("comment", "test"), ifExists = true) - val expected1_view = expected1_table - val expected2_view = expected2_table - val expected3_view = expected3_table + tableIdent, Seq("comment", "test"), ifExists = true, isView = false) + val expected1_view 
= expected1_table.copy(isView = true) + val expected2_view = expected2_table.copy(isView = true) + val expected3_view = expected3_table.copy(isView = true) comparePlans(parsed1_table, expected1_table) comparePlans(parsed2_table, expected2_table) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index c82c7f6ca6..249dcdfff5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -147,6 +147,118 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("alter views - rename") { + val tabName = "tab1" + withTable(tabName) { + sqlContext.range(10).write.saveAsTable(tabName) + val oldViewName = "view1" + val newViewName = "view2" + withView(oldViewName, newViewName) { + val catalog = hiveContext.sessionState.catalog + sql(s"CREATE VIEW $oldViewName AS SELECT * FROM $tabName") + + assert(catalog.tableExists(TableIdentifier(oldViewName))) + assert(!catalog.tableExists(TableIdentifier(newViewName))) + sql(s"ALTER VIEW $oldViewName RENAME TO $newViewName") + assert(!catalog.tableExists(TableIdentifier(oldViewName))) + assert(catalog.tableExists(TableIdentifier(newViewName))) + } + } + } + + test("alter views - set/unset tblproperties") { + val tabName = "tab1" + withTable(tabName) { + sqlContext.range(10).write.saveAsTable(tabName) + val viewName = "view1" + withView(viewName) { + val catalog = hiveContext.sessionState.catalog + sql(s"CREATE VIEW $viewName AS SELECT * FROM $tabName") + + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map()) + sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')") + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "an")) + + // no exception or message will be issued if we set it again + sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'an')") + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "an")) + + // the value will be updated if we set the same key to a different value + sql(s"ALTER VIEW $viewName SET TBLPROPERTIES ('p' = 'b')") + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map("p" -> "b")) + + sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") + assert(catalog.getTableMetadata(TableIdentifier(viewName)) + .properties.filter(_._1 != "transient_lastDdlTime") == Map()) + + val message = intercept[AnalysisException] { + sql(s"ALTER VIEW $viewName UNSET TBLPROPERTIES ('p')") + }.getMessage + assert(message.contains( + "attempted to unset non-existent property 'p' in table '`view1`'")) + } + } + } + + test("alter views and alter table - misuse") { + val tabName = "tab1" + withTable(tabName) { + sqlContext.range(10).write.saveAsTable(tabName) + val oldViewName = "view1" + val newViewName = "view2" + withView(oldViewName, newViewName) { + val catalog = hiveContext.sessionState.catalog + sql(s"CREATE VIEW $oldViewName AS SELECT * FROM $tabName") + + assert(catalog.tableExists(TableIdentifier(tabName))) + assert(catalog.tableExists(TableIdentifier(oldViewName))) + + var message = intercept[AnalysisException] { + sql(s"ALTER VIEW $tabName RENAME TO $newViewName") + }.getMessage + 
assert(message.contains( + "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER VIEW $tabName SET TBLPROPERTIES ('p' = 'an')") + }.getMessage + assert(message.contains( + "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER VIEW $tabName UNSET TBLPROPERTIES ('p')") + }.getMessage + assert(message.contains( + "Cannot alter a table with ALTER VIEW. Please use ALTER TABLE instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER TABLE $oldViewName RENAME TO $newViewName") + }.getMessage + assert(message.contains( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER TABLE $oldViewName SET TBLPROPERTIES ('p' = 'an')") + }.getMessage + assert(message.contains( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + + message = intercept[AnalysisException] { + sql(s"ALTER TABLE $oldViewName UNSET TBLPROPERTIES ('p')") + }.getMessage + assert(message.contains( + "Cannot alter a view with ALTER TABLE. Please use ALTER VIEW instead")) + + assert(catalog.tableExists(TableIdentifier(tabName))) + assert(catalog.tableExists(TableIdentifier(oldViewName))) + } + } + } + test("drop table using drop view") { withTable("tab1") { sql("CREATE TABLE tab1(c1 int)") -- cgit v1.2.3 From de2ad52855aee3c60bbc4642afb180d6fe62173b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Apr 2016 10:12:29 -0700 Subject: [SPARK-14625] TaskUIData and ExecutorUIData shouldn't be case classes ## What changes were proposed in this pull request? I was trying to understand the accumulator and metrics update source code and these two classes don't really need to be case classes. It would also be more consistent with other UI classes if they are not case classes. This is part of my bigger effort to simplify accumulators and task metrics. ## How was this patch tested? This is a straightforward refactoring without behavior change. Author: Reynold Xin Closes #12386 from rxin/SPARK-14625. 
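The visible effect at call sites is that `TaskUIData` no longer gets a compiler-generated `unapply`, so `case TaskUIData(...)` patterns have to become plain field access, which is what the `StagePage` changes below do. A reduced sketch of the before/after shape (simplified, made-up fields rather than the real class):

```scala
// Before: a case class with mutable fields; destructuring works, but the
// generated copy/equals/unapply add little value for a mutable UI holder.
case class OldUIData(var taskId: Long, var errorMessage: Option[String] = None)

// After: a plain class; callers read and update fields directly.
class NewUIData(var taskId: Long, var errorMessage: Option[String] = None)

object UIDataExample {
  def main(args: Array[String]): Unit = {
    val before = OldUIData(1L)
    val viaPattern = before match { case OldUIData(id, err) => s"$id -> $err" }

    val after = new NewUIData(1L)
    after.errorMessage = Some("fetch failed") // mutate in place, read by field
    println(s"$viaPattern / ${after.taskId} -> ${after.errorMessage}")
  }
}
```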
--- .../spark/status/api/v1/AllStagesResource.scala | 4 +- .../org/apache/spark/ui/exec/ExecutorsTab.scala | 2 +- .../apache/spark/ui/jobs/JobProgressListener.scala | 8 +- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 85 +++++++++++----------- .../scala/org/apache/spark/ui/jobs/UIData.scala | 6 +- .../spark/ui/jobs/JobProgressListenerSuite.scala | 10 +-- 6 files changed, 58 insertions(+), 57 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 9c92a50150..f8d6e9fbbb 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -147,7 +147,7 @@ private[v1] object AllStagesResource { speculative = uiData.taskInfo.speculative, accumulatorUpdates = uiData.taskInfo.accumulables.map { convertAccumulableInfo }, errorMessage = uiData.errorMessage, - taskMetrics = uiData.taskMetrics.map { convertUiTaskMetrics } + taskMetrics = uiData.metrics.map { convertUiTaskMetrics } ) } @@ -155,7 +155,7 @@ private[v1] object AllStagesResource { allTaskData: Iterable[TaskUIData], quantiles: Array[Double]): TaskMetricDistributions = { - val rawMetrics = allTaskData.flatMap{_.taskMetrics}.toSeq + val rawMetrics = allTaskData.flatMap{_.metrics}.toSeq def metricQuantiles(f: InternalTaskMetrics => Double): IndexedSeq[Double] = Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 788f35ec77..3fd0efd3a1 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -70,7 +70,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar executorToLogUrls(eid) = executorAdded.executorInfo.logUrlMap executorToTotalCores(eid) = executorAdded.executorInfo.totalCores executorToTasksMax(eid) = executorToTotalCores(eid) / conf.getInt("spark.task.cpus", 1) - executorIdToData(eid) = ExecutorUIData(executorAdded.time) + executorIdToData(eid) = new ExecutorUIData(executorAdded.time) } override def onExecutorRemoved( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index ed3ab66e3b..13f5f84d06 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -396,13 +396,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { None } taskMetrics.foreach { m => - val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.taskMetrics) + val oldMetrics = stageData.taskData.get(info.taskId).flatMap(_.metrics) updateAggregateMetrics(stageData, info.executorId, m, oldMetrics) } val taskData = stageData.taskData.getOrElseUpdate(info.taskId, new TaskUIData(info)) taskData.taskInfo = info - taskData.taskMetrics = taskMetrics + taskData.metrics = taskMetrics taskData.errorMessage = errorMessage for ( @@ -506,9 +506,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val metrics = TaskMetrics.fromAccumulatorUpdates(accumUpdates) taskData.foreach { t => if (!t.taskInfo.finished) { - updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, 
t.taskMetrics) + updateAggregateMetrics(stageData, executorMetricsUpdate.execId, metrics, t.metrics) // Overwrite task metrics - t.taskMetrics = Some(metrics) + t.metrics = Some(metrics) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 689ab7dd5e..8a44bbd9fc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -330,7 +330,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { else taskTable.dataSource.slicedTaskIds // Excludes tasks which failed and have incomplete metrics - val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) + val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.metrics.isDefined) val summaryTable: Option[Seq[Node]] = if (validTasks.size == 0) { @@ -348,8 +348,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { getDistributionQuantiles(data).map(d => ) } - val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.executorDeserializeTime.toDouble + val deserializationTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.executorDeserializeTime.toDouble } val deserializationQuantiles = +: getFormattedTimeQuantiles(deserializationTimes) - val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.executorRunTime.toDouble + val serviceTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.executorRunTime.toDouble } val serviceQuantiles = +: getFormattedTimeQuantiles(serviceTimes) - val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.jvmGCTime.toDouble + val gcTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.jvmGCTime.toDouble } val gcQuantiles = +: getFormattedTimeQuantiles(gcTimes) - val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.resultSerializationTime.toDouble + val serializationTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.resultSerializationTime.toDouble } val serializationQuantiles = +: getFormattedTimeQuantiles(serializationTimes) - val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => - getGettingResultTime(info, currentTime).toDouble + val gettingResultTimes = validTasks.map { taskUIData: TaskUIData => + getGettingResultTime(taskUIData.taskInfo, currentTime).toDouble } val gettingResultQuantiles = +: getFormattedTimeQuantiles(gettingResultTimes) - val peakExecutionMemory = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.peakExecutionMemory.toDouble + val peakExecutionMemory = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.peakExecutionMemory.toDouble } val peakExecutionMemoryQuantiles = { @@ -427,30 +427,30 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { ) } - val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble + val inputSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble + val inputRecords = validTasks.map { taskUIData: TaskUIData => + 
taskUIData.metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble } val inputQuantiles = +: getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) - val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + val outputSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val outputRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + val outputRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble } val outputQuantiles = +: getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) - val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + val shuffleReadBlockedTimes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble } val shuffleReadBlockedQuantiles = +: getFormattedTimeQuantiles(shuffleReadBlockedTimes) - val shuffleReadTotalSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble + val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble } - val shuffleReadTotalRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble + val shuffleReadTotalRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble } val shuffleReadTotalQuantiles = +: getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) - val shuffleReadRemoteSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } val shuffleReadRemoteQuantiles = +: getFormattedSizeQuantiles(shuffleReadRemoteSizes) - val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble } - val shuffleWriteRecords = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble } val shuffleWriteQuantiles = +: getFormattedSizeQuantilesWithRecords(shuffleWriteSizes, shuffleWriteRecords) - val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.memoryBytesSpilled.toDouble + val memoryBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.memoryBytesSpilled.toDouble } val memoryBytesSpilledQuantiles = +: getFormattedSizeQuantiles(memoryBytesSpilledSizes) - val diskBytesSpilledSizes = 
validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.diskBytesSpilled.toDouble + val diskBytesSpilledSizes = validTasks.map { taskUIData: TaskUIData => + taskUIData.metrics.get.diskBytesSpilled.toDouble } val diskBytesSpilledQuantiles = +: getFormattedSizeQuantiles(diskBytesSpilledSizes) @@ -601,7 +601,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { def toProportion(time: Long) = time.toDouble / totalExecutionTime * 100 - val metricsOpt = taskUIData.taskMetrics + val metricsOpt = taskUIData.metrics val shuffleReadTime = metricsOpt.flatMap(_.shuffleReadMetrics.map(_.fetchWaitTime)).getOrElse(0L) val shuffleReadTimeProportion = toProportion(shuffleReadTime) @@ -868,7 +868,8 @@ private[ui] class TaskDataSource( def slicedTaskIds: Set[Long] = _slicedTaskIds private def taskRow(taskData: TaskUIData): TaskTableRowData = { - val TaskUIData(info, metrics, errorMessage) = taskData + val info = taskData.taskInfo + val metrics = taskData.metrics val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) @@ -1014,7 +1015,7 @@ private[ui] class TaskDataSource( shuffleRead, shuffleWrite, bytesSpilled, - errorMessage.getOrElse("")) + taskData.errorMessage.getOrElse("")) } /** diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index 78165d7b74..b454ef1b20 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -105,12 +105,12 @@ private[spark] object UIData { /** * These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation. 
*/ - case class TaskUIData( + class TaskUIData( var taskInfo: TaskInfo, - var taskMetrics: Option[TaskMetrics] = None, + var metrics: Option[TaskMetrics] = None, var errorMessage: Option[String] = None) - case class ExecutorUIData( + class ExecutorUIData( val startTime: Long, var finishTime: Option[Long] = None, var finishReason: Option[String] = None) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 9876bded33..7d4c0863bc 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -322,11 +322,11 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(stage1Data.inputBytes == 207) assert(stage0Data.outputBytes == 116) assert(stage1Data.outputBytes == 208) - assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 2) - assert(stage0Data.taskData.get(1235L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1235L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 102) - assert(stage1Data.taskData.get(1236L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage1Data.taskData.get(1236L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 202) // task that was included in a heartbeat @@ -355,9 +355,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(stage1Data.inputBytes == 614) assert(stage0Data.outputBytes == 416) assert(stage1Data.outputBytes == 616) - assert(stage0Data.taskData.get(1234L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 302) - assert(stage1Data.taskData.get(1237L).get.taskMetrics.get.shuffleReadMetrics.get + assert(stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.get .totalBlocksFetched == 402) } } -- cgit v1.2.3 From 3e27940a19e7bab448f1af11d2065ecd1ec66197 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Thu, 14 Apr 2016 10:14:38 -0700 Subject: [SPARK-14630][BUILD][CORE][SQL][STREAMING] Code style: public abstract methods should have explicit return types ## What changes were proposed in this pull request? Currently many public abstract methods (in abstract classes as well as traits) don't declare return types explicitly, such as in [o.a.s.streaming.dstream.InputDStream](https://github.com/apache/spark/blob/master/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala#L110): ```scala def start() // should be: def start(): Unit def stop() // should be: def stop(): Unit ``` These methods exist in core, sql, streaming; this PR fixes them. ## How was this patch tested? N/A ## Which piece of scala style rule led to the changes? the rule was added separately in https://github.com/apache/spark/pull/12396 Author: Liwei Lin Closes #12389 from lw-lin/public-abstract-methods. 
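For context, a short hypothetical trait (not taken from the patch) showing the convention this change enforces, together with a concrete implementation:

```scala
// Hypothetical example; explicit return types make the abstract contract
// readable without checking each implementation.
trait Service {
  def start(): Unit        // rather than: def start()
  def stop(): Unit         // rather than: def stop()
  def status(): String
}

class NoopService extends Service {
  override def start(): Unit = println("started")
  override def stop(): Unit = println("stopped")
  override def status(): String = "idle"
}
```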
--- core/src/main/scala/org/apache/spark/ContextCleaner.scala | 10 +++++----- core/src/main/scala/org/apache/spark/FutureAction.scala | 4 ++-- .../org/apache/spark/deploy/client/AppClientListener.scala | 3 ++- .../org/apache/spark/deploy/master/LeaderElectionAgent.scala | 4 ++-- .../org/apache/spark/deploy/master/PersistenceEngine.scala | 4 ++-- .../scala/org/apache/spark/deploy/worker/DriverRunner.scala | 2 +- .../main/scala/org/apache/spark/executor/ExecutorBackend.scala | 2 +- .../scala/org/apache/spark/network/BlockTransferService.scala | 2 +- .../main/scala/org/apache/spark/scheduler/JobListener.scala | 4 ++-- .../scala/org/apache/spark/scheduler/SchedulableBuilder.scala | 4 ++-- .../main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 2 +- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 2 +- .../org/apache/spark/shuffle/FileShuffleBlockResolver.scala | 2 +- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 2 +- .../scala/org/apache/spark/util/logging/RollingPolicy.scala | 4 ++-- .../main/scala/org/apache/spark/util/random/Pseudorandom.scala | 2 +- .../spark/sql/catalyst/expressions/SpecificMutableRow.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 2 +- .../apache/spark/sql/execution/columnar/ColumnAccessor.scala | 2 +- .../apache/spark/sql/execution/columnar/ColumnBuilder.scala | 4 ++-- .../spark/sql/execution/streaming/state/StateStore.scala | 2 +- .../org/apache/spark/sql/util/ContinuousQueryListener.scala | 6 +++--- .../org/apache/spark/streaming/dstream/InputDStream.scala | 4 ++-- .../org/apache/spark/streaming/receiver/BlockGenerator.scala | 8 ++++---- .../apache/spark/streaming/receiver/ReceivedBlockHandler.scala | 2 +- .../scala/org/apache/spark/streaming/receiver/Receiver.scala | 4 ++-- .../apache/spark/streaming/receiver/ReceiverSupervisor.scala | 10 +++++----- 27 files changed, 50 insertions(+), 49 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 8fc657c5eb..76692ccec8 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -278,9 +278,9 @@ private object ContextCleaner { * Listener class used for testing when any item has been cleaned by the Cleaner class. */ private[spark] trait CleanerListener { - def rddCleaned(rddId: Int) - def shuffleCleaned(shuffleId: Int) - def broadcastCleaned(broadcastId: Long) - def accumCleaned(accId: Long) - def checkpointCleaned(rddId: Long) + def rddCleaned(rddId: Int): Unit + def shuffleCleaned(shuffleId: Int): Unit + def broadcastCleaned(broadcastId: Long): Unit + def accumCleaned(accId: Long): Unit + def checkpointCleaned(rddId: Long): Unit } diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index ce11772a6d..339266a5d4 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -41,7 +41,7 @@ trait FutureAction[T] extends Future[T] { /** * Cancels the execution of this action. */ - def cancel() + def cancel(): Unit /** * Blocks until this action completes. @@ -65,7 +65,7 @@ trait FutureAction[T] extends Future[T] { * When this action is completed, either through an exception, or a value, applies the provided * function. 
*/ - def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) + def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit /** * Returns whether the action has already been completed with a value or an exception. diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala index e584952a9a..94506a0cbb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala @@ -33,7 +33,8 @@ private[spark] trait AppClientListener { /** An application death is an unrecoverable failure condition. */ def dead(reason: String): Unit - def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) + def executorAdded( + fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index 70f21fbe0d..52e2854961 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -32,8 +32,8 @@ trait LeaderElectionAgent { @DeveloperApi trait LeaderElectable { - def electedLeader() - def revokedLeadership() + def electedLeader(): Unit + def revokedLeadership(): Unit } /** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index dddf2be57e..b30bc821b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -40,12 +40,12 @@ abstract class PersistenceEngine { * Defines how the object is serialized and persisted. Implementation will * depend on the store used. */ - def persist(name: String, obj: Object) + def persist(name: String, obj: Object): Unit /** * Defines how the object referred by its name is removed from the store. */ - def unpersist(name: String) + def unpersist(name: String): Unit /** * Gives all objects, matching a prefix. 
This defines how objects are diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 9c6bc5c62f..aad2e91b25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -218,7 +218,7 @@ private[deploy] class DriverRunner( } private[deploy] trait Sleeper { - def sleep(seconds: Int) + def sleep(seconds: Int): Unit } // Needed because ProcessBuilder is a final class and cannot be mocked diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala index e07cb31cbe..7153323d01 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala @@ -25,6 +25,6 @@ import org.apache.spark.TaskState.TaskState * A pluggable interface used by the Executor to send updates to the cluster scheduler. */ private[spark] trait ExecutorBackend { - def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) + def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index e43e3a2de2..09ce012e4e 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -36,7 +36,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch * local blocks or put local blocks. */ - def init(blockDataManager: BlockDataManager) + def init(blockDataManager: BlockDataManager): Unit /** * Tear down the transfer service. diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala index 50c2b9acd6..e0f7c8f021 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala @@ -23,6 +23,6 @@ package org.apache.spark.scheduler * job fails (and no further taskSucceeded events will happen). 
*/ private[spark] trait JobListener { - def taskSucceeded(index: Int, result: Any) - def jobFailed(exception: Exception) + def taskSucceeded(index: Int, result: Any): Unit + def jobFailed(exception: Exception): Unit } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 5baebe8c1f..100ed76ecb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -34,9 +34,9 @@ import org.apache.spark.util.Utils private[spark] trait SchedulableBuilder { def rootPool: Pool - def buildPools() + def buildPools(): Unit - def addTaskSetManager(manager: Schedulable, properties: Properties) + def addTaskSetManager(manager: Schedulable, properties: Properties): Unit } private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 8477a66b39..647d44a0f0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -51,7 +51,7 @@ private[spark] trait TaskScheduler { def submitTasks(taskSet: TaskSet): Unit // Cancel a stage. - def cancelTasks(stageId: Int, interruptThread: Boolean) + def cancelTasks(stageId: Int, interruptThread: Boolean): Unit // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 3d090a4353..918ae376f6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -357,7 +357,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ * serialization. */ trait KryoRegistrator { - def registerClasses(kryo: Kryo) + def registerClasses(kryo: Kryo): Unit } private[serializer] object KryoSerializer { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6cd7d69518..be1e84a2ba 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -35,7 +35,7 @@ private[spark] trait ShuffleWriterGroup { val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ - def releaseWriters(success: Boolean) + def releaseWriters(success: Boolean): Unit } /** diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 3939b111b5..2b0bc32cf6 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -129,7 +129,7 @@ private[spark] abstract class WebUI( } /** Initialize all components of the server. */ - def initialize() + def initialize(): Unit /** Bind to the HTTP server behind this web interface. 
*/ def bind() { diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index b34880d3a7..6e80db2f51 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -32,10 +32,10 @@ private[spark] trait RollingPolicy { def shouldRollover(bytesToBeWritten: Long): Boolean /** Notify that rollover has occurred */ - def rolledOver() + def rolledOver(): Unit /** Notify that bytes have been written */ - def bytesWritten(bytes: Long) + def bytesWritten(bytes: Long): Unit /** Get the desired name of the rollover file */ def generateRolledOverFileSuffix(): String diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala index 70f3dd62b9..41f28f6e51 100644 --- a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala @@ -26,5 +26,5 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi trait Pseudorandom { /** Set random seed. */ - def setSeed(seed: Long) + def setSeed(seed: Long): Unit } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 4615c55d67..61ca7272df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -62,7 +62,7 @@ import org.apache.spark.sql.types._ abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any - def update(v: Any) + def update(v: Any): Unit def copy(): MutableValue } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index be6b2530ef..93a8278528 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -164,7 +164,7 @@ trait BaseGenericInternalRow extends InternalRow { abstract class MutableRow extends InternalRow { def setNullAt(i: Int): Unit - def update(i: Int, value: Any) + def update(i: Int, value: Any): Unit // default implementation (slow) def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 78664baa56..7cde04b626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -38,7 +38,7 @@ private[columnar] trait ColumnAccessor { def hasNext: Boolean - def extractTo(row: MutableRow, ordinal: Int) + def extractTo(row: MutableRow, ordinal: Int): Unit protected def underlyingBuffer: ByteBuffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 9a173367f4..d30655e0c4 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -28,12 +28,12 @@ private[columnar] trait ColumnBuilder { /** * Initializes with an approximate lower bound on the expected number of elements in this column. */ - def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false) + def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false): Unit /** * Appends `row(ordinal)` to the column builder. */ - def appendFrom(row: InternalRow, ordinal: Int) + def appendFrom(row: InternalRow, ordinal: Int): Unit /** * Column statistics information diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index cc5327e0e2..9521506325 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -50,7 +50,7 @@ trait StateStore { def get(key: UnsafeRow): Option[UnsafeRow] /** Put a new value for a key. */ - def put(key: UnsafeRow, value: UnsafeRow) + def put(key: UnsafeRow, value: UnsafeRow): Unit /** * Remove keys that match the following condition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala index bf78be9d9f..ba1facf11b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala @@ -37,7 +37,7 @@ abstract class ContinuousQueryListener { * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. Please * don't block this method as it will block your query. */ - def onQueryStarted(queryStarted: QueryStarted) + def onQueryStarted(queryStarted: QueryStarted): Unit /** * Called when there is some status update (ingestion rate updated, etc.) @@ -47,10 +47,10 @@ abstract class ContinuousQueryListener { * may be changed before/when you process the event. E.g., you may find [[ContinuousQuery]] * is terminated when you are processing [[QueryProgress]]. */ - def onQueryProgress(queryProgress: QueryProgress) + def onQueryProgress(queryProgress: QueryProgress): Unit /** Called when a query is stopped, with or without error */ - def onQueryTerminated(queryTerminated: QueryTerminated) + def onQueryTerminated(queryTerminated: QueryTerminated): Unit } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index dc88349db5..a3c125c306 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -107,8 +107,8 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) } /** Method called to start receiving data. Subclasses must implement this method. */ - def start() + def start(): Unit /** Method called to stop receiving data. Subclasses must implement this method. 
*/ - def stop() + def stop(): Unit } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index e42bea6ec6..4592e015ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -37,7 +37,7 @@ private[streaming] trait BlockGeneratorListener { * that will be useful when a block is generated. Any long blocking operation in this callback * will hurt the throughput. */ - def onAddData(data: Any, metadata: Any) + def onAddData(data: Any, metadata: Any): Unit /** * Called when a new block of data is generated by the block generator. The block generation @@ -47,7 +47,7 @@ private[streaming] trait BlockGeneratorListener { * be useful when the block has been successfully stored. Any long blocking operation in this * callback will hurt the throughput. */ - def onGenerateBlock(blockId: StreamBlockId) + def onGenerateBlock(blockId: StreamBlockId): Unit /** * Called when a new block is ready to be pushed. Callers are supposed to store the block into @@ -55,13 +55,13 @@ private[streaming] trait BlockGeneratorListener { * thread, that is not synchronized with any other callbacks. Hence it is okay to do long * blocking operation in this callback. */ - def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) + def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit /** * Called when an error has occurred in the BlockGenerator. Can be called form many places * so better to not do any long block operation in this callback. */ - def onError(message: String, throwable: Throwable) + def onError(message: String, throwable: Throwable): Unit } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 85350ff658..7aea1c9b64 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -48,7 +48,7 @@ private[streaming] trait ReceivedBlockHandler { def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): ReceivedBlockStoreResult /** Cleanup old blocks older than the given threshold time */ - def cleanupOldBlocks(threshTime: Long) + def cleanupOldBlocks(threshTime: Long): Unit } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 3376cd557d..5157ca62dc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -99,13 +99,13 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * (iii) `restart(...)` can be called to restart the receiver. This will call `onStop()` * immediately, and then `onStart()` after a delay. */ - def onStart() + def onStart(): Unit /** * This method is called by the system when the receiver is stopped. All resources * (threads, buffers, etc.) set up in `onStart()` must be cleaned up in this method. */ - def onStop() + def onStop(): Unit /** Override this to specify a preferred location (hostname). 
*/ def preferredLocation: Option[String] = None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index e0fe8d2206..42fc84c19b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -70,28 +70,28 @@ private[streaming] abstract class ReceiverSupervisor( @volatile private[streaming] var receiverState = Initialized /** Push a single data item to backend data store. */ - def pushSingle(data: Any) + def pushSingle(data: Any): Unit /** Store the bytes of received data as a data block into Spark's memory. */ def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** Store a iterator of received data as a data block into Spark's memory. */ def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** * Create a custom [[BlockGenerator]] that the receiver implementation can directly control @@ -103,7 +103,7 @@ private[streaming] abstract class ReceiverSupervisor( def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator /** Report errors. */ - def reportError(message: String, throwable: Throwable) + def reportError(message: String, throwable: Throwable): Unit /** * Called when supervisor is started. -- cgit v1.2.3 From 9fa43a33b91c3a9b6be39bf3e00febf61a4b5b59 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 14 Apr 2016 10:48:17 -0700 Subject: [SPARK-14612][ML] Consolidate the version of dependencies in mllib and mllib-local into one place ## What changes were proposed in this pull request? Move json4s, breeze dependency declaration into parent ## How was this patch tested? Should be no functional change, but Jenkins tests will test that. Author: Sean Owen Closes #12390 from srowen/SPARK-14612. 
--- core/pom.xml | 1 - mllib-local/pom.xml | 13 ------------- mllib/pom.xml | 13 ------------- pom.xml | 22 ++++++++++++++++++++++ 4 files changed, 22 insertions(+), 27 deletions(-) diff --git a/core/pom.xml b/core/pom.xml index 4c7e3a3662..7349ad35b9 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -192,7 +192,6 @@ org.json4s json4s-jackson_${scala.binary.version} - 3.2.10 com.sun.jersey diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index c56561f215..68f15dd905 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -38,19 +38,6 @@ org.scalanlp breeze_${scala.binary.version} - 0.11.2 - - - - junit - junit - - - org.apache.commons - commons-math3 - - org.apache.commons diff --git a/mllib/pom.xml b/mllib/pom.xml index e56eafc300..24d8274e22 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -77,19 +77,6 @@ org.scalanlp breeze_${scala.binary.version} - 0.11.2 - - - - junit - junit - - - org.apache.commons - commons-math3 - - org.apache.commons diff --git a/pom.xml b/pom.xml index 4585c8b9c2..a772d51337 100644 --- a/pom.xml +++ b/pom.xml @@ -584,6 +584,28 @@ ${jersey.version} ${hadoop.deps.scope} + + org.scalanlp + breeze_${scala.binary.version} + 0.11.2 + + + + junit + junit + + + org.apache.commons + commons-math3 + + + + + org.json4s + json4s-jackson_${scala.binary.version} + 3.2.10 + com.sun.jersey jersey-json -- cgit v1.2.3 From dac40b68dc52d5ab855dfde63f0872064aa3d999 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Apr 2016 10:54:57 -0700 Subject: [SPARK-14619] Track internal accumulators (metrics) by stage attempt ## What changes were proposed in this pull request? When there are multiple attempts for a stage, we currently only reset internal accumulator values if all the tasks are resubmitted. It would make more sense to reset the accumulator values for each stage attempt. This will allow us to eventually get rid of the internal flag in the Accumulator class. This is part of my bigger effort to simplify accumulators and task metrics. ## How was this patch tested? Covered by existing tests. Author: Reynold Xin Closes #12378 from rxin/SPARK-14619. --- .../scala/org/apache/spark/InternalAccumulator.scala | 2 +- .../org/apache/spark/scheduler/DAGScheduler.scala | 11 ++--------- .../main/scala/org/apache/spark/scheduler/Stage.scala | 19 ++----------------- .../scala/org/apache/spark/scheduler/StageInfo.scala | 10 +++++++++- .../main/scala/org/apache/spark/ui/jobs/JobPage.scala | 2 +- .../scala/org/apache/spark/util/JsonProtocol.scala | 6 ++++-- .../apache/spark/ExecutorAllocationManagerSuite.scala | 4 ++-- .../test/scala/org/apache/spark/ShuffleSuite.scala | 6 +++--- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../sql/execution/UnsafeRowSerializerSuite.scala | 2 +- 10 files changed, 26 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 7aa9057858..0dd4ec656f 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -187,7 +187,7 @@ private[spark] object InternalAccumulator { * add to the same set of accumulators. We do this to report the distribution of accumulator * values across all tasks within each stage. 
*/ - def create(sc: SparkContext): Seq[Accumulator[_]] = { + def createAll(sc: SparkContext): Seq[Accumulator[_]] = { val accums = createAll() accums.foreach { accum => Accumulators.register(accum) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4609b244e6..c27aad268d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -950,13 +950,6 @@ class DAGScheduler( // First figure out the indexes of partition ids to compute. val partitionsToCompute: Seq[Int] = stage.findMissingPartitions() - // Create internal accumulators if the stage has no accumulators initialized. - // Reset internal accumulators only if this stage is not partially submitted - // Otherwise, we may override existing accumulator values from some tasks - if (stage.internalAccumulators.isEmpty || stage.numPartitions == partitionsToCompute.size) { - stage.resetInternalAccumulators() - } - // Use the scheduling pool, job group, description, etc. from an ActiveJob associated // with this Stage val properties = jobIdToActiveJob(jobId).properties @@ -1036,7 +1029,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.internalAccumulators, properties) + taskBinary, part, locs, stage.latestInfo.internalAccumulators, properties) } case stage: ResultStage => @@ -1046,7 +1039,7 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, properties, stage.internalAccumulators) + taskBinary, part, locs, id, properties, stage.latestInfo.internalAccumulators) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index a40b700cdd..b6d4e39fe5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -75,22 +75,6 @@ private[scheduler] abstract class Stage( val name: String = callSite.shortForm val details: String = callSite.longForm - private var _internalAccumulators: Seq[Accumulator[_]] = Seq.empty - - /** Internal accumulators shared across all tasks in this stage. */ - def internalAccumulators: Seq[Accumulator[_]] = _internalAccumulators - - /** - * Re-initialize the internal accumulators associated with this stage. - * - * This is called every time the stage is submitted, *except* when a subset of tasks - * belonging to this stage has already finished. Otherwise, reinitializing the internal - * accumulators here again will override partial values from the finished tasks. - */ - def resetInternalAccumulators(): Unit = { - _internalAccumulators = InternalAccumulator.create(rdd.sparkContext) - } - /** * Pointer to the [StageInfo] object for the most recent attempt. 
This needs to be initialized * here, before any attempts have actually been created, because the DAGScheduler uses this @@ -127,7 +111,8 @@ private[scheduler] abstract class Stage( numPartitionsToCompute: Int, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { _latestInfo = StageInfo.fromStage( - this, nextAttemptId, Some(numPartitionsToCompute), taskLocalityPreferences) + this, nextAttemptId, Some(numPartitionsToCompute), + InternalAccumulator.createAll(rdd.sparkContext), taskLocalityPreferences) nextAttemptId += 1 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 24796c1430..0fd58c41cd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashMap +import org.apache.spark.Accumulator import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.RDDInfo @@ -35,6 +36,7 @@ class StageInfo( val rddInfos: Seq[RDDInfo], val parentIds: Seq[Int], val details: String, + val internalAccumulators: Seq[Accumulator[_]] = Seq.empty, private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None @@ -42,7 +44,11 @@ class StageInfo( var completionTime: Option[Long] = None /** If the stage failed, the reason why. */ var failureReason: Option[String] = None - /** Terminal values of accumulables updated during this stage. */ + + /** + * Terminal values of accumulables updated during this stage, including all the user-defined + * accumulators. 
+ */ val accumulables = HashMap[Long, AccumulableInfo]() def stageFailed(reason: String) { @@ -75,6 +81,7 @@ private[spark] object StageInfo { stage: Stage, attemptId: Int, numTasks: Option[Int] = None, + internalAccumulators: Seq[Accumulator[_]] = Seq.empty, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) @@ -87,6 +94,7 @@ private[spark] object StageInfo { rddInfos, stage.parents.map(_.id), stage.details, + internalAccumulators, taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 645e2d2e36..bd4797ae8e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -203,7 +203,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { // This could be empty if the JobProgressListener hasn't received information about the // stage or if the stage information has been garbage collected listener.stageIdToInfo.getOrElse(stageId, - new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown")) + new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown", Seq.empty)) } val activeStages = Buffer[StageInfo]() diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 09d955300a..3b78458065 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -578,7 +578,9 @@ private[spark] object JsonProtocol { // The "Stage Infos" field was added in Spark 1.2.0 val stageInfos = Utils.jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { - stageIds.map(id => new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown")) + stageIds.map { id => + new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", Seq.empty) + } } SparkListenerJobStart(jobId, submissionTime, stageInfos, properties) } @@ -686,7 +688,7 @@ private[spark] object JsonProtocol { } val stageInfo = new StageInfo( - stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details) + stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details, Seq.empty) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index 80a1de6065..ee6b991461 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -928,8 +928,8 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { numTasks: Int, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty ): StageInfo = { - new StageInfo( - stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", taskLocalityPreferences) + new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", + Seq.empty, taskLocalityPreferences) } private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 00f3f15c45..cd7d2e1570 
100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -337,7 +337,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem, - InternalAccumulator.create(sc))) + InternalAccumulator.createAll(sc))) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, @@ -345,7 +345,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // depending on what gets spilled, what gets combined, etc. val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem, - InternalAccumulator.create(sc))) + InternalAccumulator.createAll(sc))) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -374,7 +374,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem, - InternalAccumulator.create(sc))) + InternalAccumulator.createAll(sc))) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2293c11dad..fd96fb04f8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1144,7 +1144,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // SPARK-9809 -- this stage is submitted without a task for each partition (because some of // the shuffle map output is still available from stage 0); make sure we've still got internal // accumulators setup - assert(scheduler.stageIdToStage(2).internalAccumulators.nonEmpty) + assert(scheduler.stageIdToStage(2).latestInfo.internalAccumulators.nonEmpty) completeShuffleMapStageSuccessfully(2, 0, 2) completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) assert(results === Map(0 -> 1234, 1 -> 1235)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 7db1f9654b..01687877ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) val taskContext = new TaskContextImpl( - 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.create(sc)) + 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.createAll(sc)) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, -- cgit v1.2.3 From a46f98d3f4ba6a79f4ef789806fec80a7d4f342d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 14 Apr 2016 10:56:13 -0700 Subject: [SPARK-14617] Remove deprecated APIs in 
TaskMetrics ## What changes were proposed in this pull request? This patch removes some of the deprecated APIs in TaskMetrics. This is part of my bigger effort to simplify accumulators and task metrics. ## How was this patch tested? N/A - only removals Author: Reynold Xin Closes #12375 from rxin/SPARK-14617. --- .../org/apache/spark/executor/InputMetrics.scala | 32 ++-------------------- .../org/apache/spark/executor/OutputMetrics.scala | 30 -------------------- .../apache/spark/executor/ShuffleReadMetrics.scala | 21 ++++++++++++++ .../org/apache/spark/executor/TaskMetrics.scala | 27 +----------------- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 4 +-- .../scala/org/apache/spark/rdd/NewHadoopRDD.scala | 4 +-- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 +-- .../scala/org/apache/spark/util/JsonProtocol.scala | 4 +-- .../apache/spark/executor/TaskMetricsSuite.scala | 4 +-- .../org/apache/spark/util/JsonProtocolSuite.scala | 2 +- project/MimaExcludes.scala | 5 +++- 11 files changed, 40 insertions(+), 97 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index 6d30d3c76a..83e11c5e23 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -81,35 +81,9 @@ class InputMetrics private ( */ def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) - // Once incBytesRead & intRecordsRead is ready to be removed from the public API - // we can remove the internal versions and make the previous public API private. - // This has been done to suppress warnings when building. - @deprecated("incrementing input metrics is for internal use only", "2.0.0") - def incBytesRead(v: Long): Unit = _bytesRead.add(v) - private[spark] def incBytesReadInternal(v: Long): Unit = _bytesRead.add(v) - @deprecated("incrementing input metrics is for internal use only", "2.0.0") - def incRecordsRead(v: Long): Unit = _recordsRead.add(v) - private[spark] def incRecordsReadInternal(v: Long): Unit = _recordsRead.add(v) + private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v) + private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) - private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = - _readMethod.setValue(v.toString) + private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = _readMethod.setValue(v.toString) } - -/** - * Deprecated methods to preserve case class matching behavior before Spark 2.0. 
- */ -object InputMetrics { - - @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0") - def apply(readMethod: DataReadMethod.Value): InputMetrics = { - val im = new InputMetrics - im.setReadMethod(readMethod) - im - } - - @deprecated("matching on InputMetrics will not be supported in the future", "2.0.0") - def unapply(input: InputMetrics): Option[DataReadMethod.Value] = { - Some(input.readMethod) - } -} diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala index 0b37d559c7..93f953846f 100644 --- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -51,18 +51,6 @@ class OutputMetrics private ( TaskMetrics.getAccum[String](accumMap, InternalAccumulator.output.WRITE_METHOD)) } - /** - * Create a new [[OutputMetrics]] that is not associated with any particular task. - * - * This is only used for preserving matching behavior on [[OutputMetrics]], which used to be - * a case class before Spark 2.0. Once we remove support for matching on [[OutputMetrics]] - * we can remove this constructor as well. - */ - private[executor] def this() { - this(InternalAccumulator.createOutputAccums() - .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]) - } - /** * Total number of bytes written. */ @@ -84,21 +72,3 @@ class OutputMetrics private ( _writeMethod.setValue(v.toString) } - -/** - * Deprecated methods to preserve case class matching behavior before Spark 2.0. - */ -object OutputMetrics { - - @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0") - def apply(writeMethod: DataWriteMethod.Value): OutputMetrics = { - val om = new OutputMetrics - om.setWriteMethod(writeMethod) - om - } - - @deprecated("matching on OutputMetrics will not be supported in the future", "2.0.0") - def unapply(output: OutputMetrics): Option[DataWriteMethod.Value] = { - Some(output.writeMethod) - } -} diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 50bb645d97..71a24770b5 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -116,4 +116,25 @@ class ShuffleReadMetrics private ( private[spark] def setFetchWaitTime(v: Long): Unit = _fetchWaitTime.setValue(v) private[spark] def setRecordsRead(v: Long): Unit = _recordsRead.setValue(v) + /** + * Resets the value of the current metrics (`this`) and and merges all the independent + * [[ShuffleReadMetrics]] into `this`. 
+ */ + private[spark] def setMergeValues(metrics: Seq[ShuffleReadMetrics]): Unit = { + _remoteBlocksFetched.setValue(_remoteBlocksFetched.zero) + _localBlocksFetched.setValue(_localBlocksFetched.zero) + _remoteBytesRead.setValue(_remoteBytesRead.zero) + _localBytesRead.setValue(_localBytesRead.zero) + _fetchWaitTime.setValue(_fetchWaitTime.zero) + _recordsRead.setValue(_recordsRead.zero) + metrics.foreach { metric => + _remoteBlocksFetched.add(metric.remoteBlocksFetched) + _localBlocksFetched.add(metric.localBlocksFetched) + _remoteBytesRead.add(metric.remoteBytesRead) + _localBytesRead.add(metric.localBytesRead) + _fetchWaitTime.add(metric.fetchWaitTime) + _recordsRead.add(metric.recordsRead) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 02219a84ab..bda2a91d9d 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -139,16 +139,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se */ def updatedBlockStatuses: Seq[(BlockId, BlockStatus)] = _updatedBlockStatuses.localValue - @deprecated("use updatedBlockStatuses instead", "2.0.0") - def updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = { - if (updatedBlockStatuses.nonEmpty) Some(updatedBlockStatuses) else None - } - - @deprecated("setting updated blocks is not allowed", "2.0.0") - def updatedBlocks_=(blocks: Option[Seq[(BlockId, BlockStatus)]]): Unit = { - blocks.foreach(setUpdatedBlockStatuses) - } - // Setters and increment-ers private[spark] def setExecutorDeserializeTime(v: Long): Unit = _executorDeserializeTime.setValue(v) @@ -225,11 +215,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se */ def outputMetrics: Option[OutputMetrics] = _outputMetrics - @deprecated("setting OutputMetrics is for internal use only", "2.0.0") - def outputMetrics_=(om: Option[OutputMetrics]): Unit = { - _outputMetrics = om - } - /** * Get or create a new [[OutputMetrics]] associated with this task. */ @@ -285,12 +270,7 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se private[spark] def mergeShuffleReadMetrics(): Unit = synchronized { if (tempShuffleReadMetrics.nonEmpty) { val metrics = new ShuffleReadMetrics(initialAccumsMap) - metrics.setRemoteBlocksFetched(tempShuffleReadMetrics.map(_.remoteBlocksFetched).sum) - metrics.setLocalBlocksFetched(tempShuffleReadMetrics.map(_.localBlocksFetched).sum) - metrics.setFetchWaitTime(tempShuffleReadMetrics.map(_.fetchWaitTime).sum) - metrics.setRemoteBytesRead(tempShuffleReadMetrics.map(_.remoteBytesRead).sum) - metrics.setLocalBytesRead(tempShuffleReadMetrics.map(_.localBytesRead).sum) - metrics.setRecordsRead(tempShuffleReadMetrics.map(_.recordsRead).sum) + metrics.setMergeValues(tempShuffleReadMetrics) _shuffleReadMetrics = Some(metrics) } } @@ -306,11 +286,6 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se */ def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics - @deprecated("setting ShuffleWriteMetrics is for internal use only", "2.0.0") - def shuffleWriteMetrics_=(swm: Option[ShuffleWriteMetrics]): Unit = { - _shuffleWriteMetrics = swm - } - /** * Get or create a new [[ShuffleWriteMetrics]] associated with this task. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f7c646c668..35d190b464 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -259,7 +259,7 @@ class HadoopRDD[K, V]( finished = true } if (!finished) { - inputMetrics.incRecordsReadInternal(1) + inputMetrics.incRecordsRead(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -291,7 +291,7 @@ class HadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesReadInternal(split.inputSplit.value.getLength) + inputMetrics.incBytesRead(split.inputSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index fb9606ae38..3ccd616cbf 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -189,7 +189,7 @@ class NewHadoopRDD[K, V]( } havePair = false if (!finished) { - inputMetrics.incRecordsReadInternal(1) + inputMetrics.incRecordsRead(1) } if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { updateBytesRead() @@ -220,7 +220,7 @@ class NewHadoopRDD[K, V]( // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { - inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) } catch { case e: java.io.IOException => logWarning("Unable to get input size to set InputMetrics for task", e) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 032939b49a..36ff3bcaae 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -333,10 +333,10 @@ abstract class RDD[T: ClassTag]( case Left(blockResult) => if (readCachedBlock) { val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) - existingMetrics.incBytesReadInternal(blockResult.bytes) + existingMetrics.incBytesRead(blockResult.bytes) new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) { override def next(): T = { - existingMetrics.incRecordsReadInternal(1) + existingMetrics.incRecordsRead(1) delegate.next() } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 3b78458065..558767e36f 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -813,8 +813,8 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Input Metrics").foreach { inJson => val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String]) val inputMetrics = metrics.registerInputMetrics(readMethod) - inputMetrics.incBytesReadInternal((inJson \ "Bytes Read").extract[Long]) - inputMetrics.incRecordsReadInternal((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) + inputMetrics.incBytesRead((inJson \ "Bytes 
Read").extract[Long]) + inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) } // Updated blocks diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index 088b05403c..d91f50f18f 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -285,8 +285,8 @@ class TaskMetricsSuite extends SparkFunSuite { // set and increment values in.setBytesRead(1L) in.setBytesRead(2L) - in.incRecordsReadInternal(1L) - in.incRecordsReadInternal(2L) + in.incRecordsRead(1L) + in.incRecordsRead(2L) in.setReadMethod(DataReadMethod.Disk) // assert new values exist assertValEquals(_.bytesRead, BYTES_READ, 2L) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 6a2d4c9f2c..de6f408fa8 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -853,7 +853,7 @@ private[spark] object JsonProtocolSuite extends Assertions { if (hasHadoopInput) { val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop) inputMetrics.setBytesRead(d + e + f) - inputMetrics.incRecordsReadInternal(if (hasRecords) (d + e + f) / 100 else -1) + inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) } else { val sr = t.registerTempShuffleReadMetrics() sr.incRemoteBytesRead(b + d) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 313bf93b5d..71f337ce1f 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -627,7 +627,10 @@ object MimaExcludes { ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") ) ++ Seq( // [SPARK-14475] Propagate user-defined context from driver to executors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty") + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), + // [SPARK-14617] Remove deprecated APIs in TaskMetrics + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$") ) case v if v.startsWith("1.6") => Seq( -- cgit v1.2.3 From 1d04c86fc575470e15f6667076377cea102552d7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 14 Apr 2016 10:58:06 -0700 Subject: [SPARK-14558][CORE] In ClosureCleaner, clean the outer pointer if it's a REPL line object ## What changes were proposed in this pull request? When we clean a closure, if its outermost parent is not a closure, we won't clone and clean it as cloning user's objects is dangerous. However, if it's a REPL line object, which may carry a lot of unnecessary references(like hadoop conf, spark conf, etc.), we should clean it as it's not a user object. This PR improves the check for user's objects to exclude REPL line object. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #12327 from cloud-fan/closure. 
--- .../org/apache/spark/util/ClosureCleaner.scala | 53 ++++++++++------------ .../scala/org/apache/spark/repl/ReplSuite.scala | 27 +++++++++++ 2 files changed, 50 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 2f6924f7de..489688cb08 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -19,7 +19,8 @@ package org.apache.spark.util import java.io.{ByteArrayInputStream, ByteArrayOutputStream} -import scala.collection.mutable.{Map, Set} +import scala.collection.mutable.{Map, Set, Stack} +import scala.language.existentials import org.apache.xbean.asm5.{ClassReader, ClassVisitor, MethodVisitor, Type} import org.apache.xbean.asm5.Opcodes._ @@ -77,35 +78,19 @@ private[spark] object ClosureCleaner extends Logging { */ private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = { val seen = Set[Class[_]](obj.getClass) - var stack = List[Class[_]](obj.getClass) + val stack = Stack[Class[_]](obj.getClass) while (!stack.isEmpty) { - val cr = getClassReader(stack.head) - stack = stack.tail + val cr = getClassReader(stack.pop()) val set = Set[Class[_]]() cr.accept(new InnerClosureFinder(set), 0) for (cls <- set -- seen) { seen += cls - stack = cls :: stack + stack.push(cls) } } (seen - obj.getClass).toList } - private def createNullValue(cls: Class[_]): AnyRef = { - if (cls.isPrimitive) { - cls match { - case java.lang.Boolean.TYPE => new java.lang.Boolean(false) - case java.lang.Character.TYPE => new java.lang.Character('\u0000') - case java.lang.Void.TYPE => - // This should not happen because `Foo(void x) {}` does not compile. - throw new IllegalStateException("Unexpected void parameter in constructor") - case _ => new java.lang.Byte(0: Byte) - } - } else { - null - } - } - /** * Clean the given closure in place. * @@ -233,16 +218,24 @@ private[spark] object ClosureCleaner extends Logging { // Note that all outer objects but the outermost one (first one in this list) must be closures var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip outerObjects).reverse var parent: AnyRef = null - if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) { - // The closure is ultimately nested inside a class; keep the object of that - // class without cloning it since we don't want to clone the user's objects. - // Note that we still need to keep around the outermost object itself because - // we need it to clone its child closure later (see below). - logDebug(s" + outermost object is not a closure, so do not clone it: ${outerPairs.head}") - parent = outerPairs.head._2 // e.g. SparkContext - outerPairs = outerPairs.tail - } else if (outerPairs.size > 0) { - logDebug(s" + outermost object is a closure, so we just keep it: ${outerPairs.head}") + if (outerPairs.size > 0) { + val (outermostClass, outermostObject) = outerPairs.head + if (isClosure(outermostClass)) { + logDebug(s" + outermost object is a closure, so we clone it: ${outerPairs.head}") + } else if (outermostClass.getName.startsWith("$line")) { + // SPARK-14558: if the outermost object is a REPL line object, we should clone and clean it + // as it may carray a lot of unnecessary information, e.g. hadoop conf, spark conf, etc. 
+ logDebug(s" + outermost object is a REPL line object, so we clone it: ${outerPairs.head}") + } else { + // The closure is ultimately nested inside a class; keep the object of that + // class without cloning it since we don't want to clone the user's objects. + // Note that we still need to keep around the outermost object itself because + // we need it to clone its child closure later (see below). + logDebug(" + outermost object is not a closure or REPL line object, so do not clone it: " + + outerPairs.head) + parent = outermostObject // e.g. SparkContext + outerPairs = outerPairs.tail + } } else { logDebug(" + there are no enclosing objects!") } diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 7e10f15226..d3dafe9c42 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -373,4 +373,31 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) } + + test("should clone and clean line object in ClosureCleaner") { + val output = runInterpreterInPasteMode("local-cluster[1,4,4096]", + """ + |import org.apache.spark.rdd.RDD + | + |val lines = sc.textFile("pom.xml") + |case class Data(s: String) + |val dataRDD = lines.map(line => Data(line.take(3))) + |dataRDD.cache.count + |val repartitioned = dataRDD.repartition(dataRDD.partitions.size) + |repartitioned.cache.count + | + |def getCacheSize(rdd: RDD[_]) = { + | sc.getRDDStorageInfo.filter(_.id == rdd.id).map(_.memSize).sum + |} + |val cacheSize1 = getCacheSize(dataRDD) + |val cacheSize2 = getCacheSize(repartitioned) + | + |// The cache size of dataRDD and the repartitioned one should be similar. + |val deviation = math.abs(cacheSize2 - cacheSize1).toDouble / cacheSize1 + |assert(deviation < 0.2, + | s"deviation too large: $deviation, first size: $cacheSize1, second size: $cacheSize2") + """.stripMargin) + assertDoesNotContain("AssertionError", output) + assertDoesNotContain("Exception", output) + } } -- cgit v1.2.3 From c971aee40d806ed02d3d6a5cc478b63654052e54 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 14 Apr 2016 11:03:19 -0700 Subject: [SPARK-14499][SQL][TEST] Drop Partition Does Not Delete Data of External Tables #### What changes were proposed in this pull request? This PR is to add a test to ensure drop partitions of an external table will not delete data. cc yhuai andrewor14 #### How was this patch tested? N/A Author: gatorsmile This patch had conflicts when merged, resolved by Committer: Andrew Or Closes #12350 from gatorsmile/testDropPartition. 
--- .../spark/sql/hive/execution/HiveDDLSuite.scala | 67 ++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 249dcdfff5..206d911e0d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import java.io.File + import org.apache.hadoop.fs.Path import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} @@ -126,6 +128,71 @@ class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("add/drop partitions - external table") { + val catalog = hiveContext.sessionState.catalog + withTempDir { tmpDir => + val basePath = tmpDir.getCanonicalPath + val partitionPath_1stCol_part1 = new File(basePath + "/ds=2008-04-08") + val partitionPath_1stCol_part2 = new File(basePath + "/ds=2008-04-09") + val partitionPath_part1 = new File(basePath + "/ds=2008-04-08/hr=11") + val partitionPath_part2 = new File(basePath + "/ds=2008-04-09/hr=11") + val partitionPath_part3 = new File(basePath + "/ds=2008-04-08/hr=12") + val partitionPath_part4 = new File(basePath + "/ds=2008-04-09/hr=12") + val dirSet = + tmpDir :: partitionPath_1stCol_part1 :: partitionPath_1stCol_part2 :: + partitionPath_part1 :: partitionPath_part2 :: partitionPath_part3 :: + partitionPath_part4 :: Nil + + val externalTab = "extTable_with_partitions" + withTable(externalTab) { + assert(tmpDir.listFiles.isEmpty) + sql( + s""" + |CREATE EXTERNAL TABLE $externalTab (key INT, value STRING) + |PARTITIONED BY (ds STRING, hr STRING) + |LOCATION '$basePath' + """.stripMargin) + + // Before data insertion, all the directories are empty + assert(dirSet.forall(dir => dir.listFiles == null || dir.listFiles.isEmpty)) + + for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { + sql( + s""" + |INSERT OVERWRITE TABLE $externalTab + |partition (ds='$ds',hr='$hr') + |SELECT 1, 'a' + """.stripMargin) + } + + val hiveTable = catalog.getTableMetadata(TableIdentifier(externalTab, Some("default"))) + assert(hiveTable.tableType == CatalogTableType.EXTERNAL_TABLE) + // After data insertion, all the directories are non-empty + assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + + sql( + s""" + |ALTER TABLE $externalTab DROP PARTITION (ds='2008-04-08'), + |PARTITION (ds='2008-04-09', hr='12') + """.stripMargin) + assert(catalog.listPartitions(TableIdentifier(externalTab)).map(_.spec).toSet == + Set(Map("ds" -> "2008-04-09", "hr" -> "11"))) + // drop partition will not delete the data of external table + assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + + sql(s"ALTER TABLE $externalTab ADD PARTITION (ds='2008-04-08', hr='12')") + assert(catalog.listPartitions(TableIdentifier(externalTab)).map(_.spec).toSet == + Set(Map("ds" -> "2008-04-08", "hr" -> "12"), Map("ds" -> "2008-04-09", "hr" -> "11"))) + // add partition will not delete the data + assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + + sql(s"DROP TABLE $externalTab") + // drop table will not delete the data of external table + assert(dirSet.forall(dir => dir.listFiles.nonEmpty)) + } + } + } + test("drop views") { withTable("tab1") { val tabName = "tab1" -- cgit v1.2.3 From 28efdd3fd789fa2ebed5be03b36ca0f682e37669 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 Apr 2016 11:08:08 -0700
Subject: [SPARK-14592][SQL] Native support for CREATE TABLE LIKE DDL command ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-14592 This patch adds native support for DDL command `CREATE TABLE LIKE`. The SQL syntax is like: CREATE TABLE table_name LIKE existing_table CREATE TABLE IF NOT EXISTS table_name LIKE existing_table ## How was this patch tested? `HiveDDLCommandSuite`. `HiveQuerySuite` already tests `CREATE TABLE LIKE`. Author: Liang-Chi Hsieh This patch had conflicts when merged, resolved by Committer: Andrew Or Closes #12362 from viirya/create-table-like. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 7 ++-- .../spark/sql/execution/command/tables.scala | 40 ++++++++++++++++++++-- .../hive/execution/HiveCompatibilitySuite.scala | 4 ++- .../spark/sql/hive/execution/HiveSqlParser.scala | 13 ++++++- .../spark/sql/hive/HiveDDLCommandSuite.scala | 24 ++++++++++++- 5 files changed, 79 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index a937ad1eb7..9cf2dd257e 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -55,6 +55,8 @@ statement rowFormat? createFileFormat? locationSpec? (TBLPROPERTIES tablePropertyList)? (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier + LIKE source=tableIdentifier #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq?)? #analyze | ALTER (TABLE | VIEW) from=tableIdentifier @@ -136,10 +138,7 @@ statement ; hiveNativeCommands - : createTableHeader LIKE tableIdentifier - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? - | DELETE FROM tableIdentifier (WHERE booleanExpression)? + : DELETE FROM tableIdentifier (WHERE booleanExpression)? | TRUNCATE TABLE tableIdentifier partitionSpec? (COLUMNS identifierList)? | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e315598daa..0b41985174 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -17,9 +17,45 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} + +/** + * A command to create a table with the same definition of the given existing table. 
+ * + * The syntax of using this command in SQL is: + * {{{ + * CREATE TABLE [IF NOT EXISTS] [db_name.]table_name + * LIKE [other_db_name.]existing_table_name + * }}} + */ +case class CreateTableLike( + targetTable: TableIdentifier, + sourceTable: TableIdentifier, + ifNotExists: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (!catalog.tableExists(sourceTable)) { + throw new AnalysisException( + s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'") + } + if (catalog.isTemporaryTable(sourceTable)) { + throw new AnalysisException( + s"Source table in CREATE TABLE LIKE cannot be temporary: '$sourceTable'") + } + + val tableToCreate = catalog.getTableMetadata(sourceTable).copy( + identifier = targetTable, + tableType = CatalogTableType.MANAGED_TABLE, + createTime = System.currentTimeMillis, + lastAccessTime = -1).withNewStorage(locationUri = None) + + catalog.createTable(tableToCreate, ifNotExists) + Seq.empty[Row] + } +} // TODO: move the rest of the table commands from ddl.scala to this file diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index a45d180464..989e68aebe 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -416,6 +416,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "skewjoinopt18", "skewjoinopt9", + // This test tries to create a table like with TBLPROPERTIES clause, which we don't support. + "create_like_tbl_props", + // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", @@ -537,7 +540,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "count", "cp_mj_rc", "create_insert_outputformat", - "create_like_tbl_props", "create_nested_type", "create_struct_table", "create_view_translate", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 8c707079a1..a97b65e27b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -31,8 +31,10 @@ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlAstBuilder -import org.apache.spark.sql.execution.command.CreateTable +import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView, HiveSerDe} +import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveMetastoreTypes, HiveSerDe} +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper /** * Concrete parser for HiveQl statements. @@ -231,6 +233,15 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } } + /** + * Create a [[CreateTableLike]] command. 
+ */ + override def visitCreateTableLike(ctx: CreateTableLikeContext): LogicalPlan = withOrigin(ctx) { + val targetTable = visitTableIdentifier(ctx.target) + val sourceTable = visitTableIdentifier(ctx.source) + CreateTableLike(targetTable, sourceTable, ctx.EXISTS != null) + } + /** * Create or replace a view. This creates a [[CreateViewAsSelect]] command. * diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index b87f035cd1..110c6d19d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} -import org.apache.spark.sql.execution.command.CreateTable +import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser} class HiveDDLCommandSuite extends PlanTest { @@ -557,4 +557,26 @@ class HiveDDLCommandSuite extends PlanTest { assertUnsupported("MSCK REPAIR TABLE tab1") } + test("create table like") { + val v1 = "CREATE TABLE table1 LIKE table2" + val (target, source, exists) = parser.parsePlan(v1).collect { + case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + }.head + assert(exists == false) + assert(target.database.isEmpty) + assert(target.table == "table1") + assert(source.database.isEmpty) + assert(source.table == "table2") + + val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" + val (target2, source2, exists2) = parser.parsePlan(v2).collect { + case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + }.head + assert(exists2) + assert(target2.database.isEmpty) + assert(target2.table == "table1") + assert(source2.database.isEmpty) + assert(source2.table == "table2") + } + } -- cgit v1.2.3 From c5172f8205beabe58c0b5392c0d83f9fb9c27f18 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 14 Apr 2016 20:47:31 +0200 Subject: [SPARK-13967][PYSPARK][ML] Added binary Param to Python CountVectorizer Added binary toggle param to CountVectorizer feature transformer in PySpark. Created a unit test for using CountVectorizer with the binary toggle on. Author: Bryan Cutler Closes #12308 from BryanCutler/binary-param-python-CountVectorizer-SPARK-13967. --- python/pyspark/ml/feature.py | 34 +++++++++++++++++++++++++++++----- python/pyspark/ml/tests.py | 16 ++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 86b53285b5..0b0c573eea 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -256,24 +256,33 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, vocabSize = Param( Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." 
+ + " Default False", typeConverter=TypeConverters.toBoolean) @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, + outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ + outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, + outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ + outputCol=None) Set the params for the CountVectorizer """ kwargs = self.setParams._input_kwargs @@ -324,6 +333,21 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ return self.getOrDefault(self.vocabSize) + @since("2.0.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + self._paramMap[self.binary] = value + return self + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + def _create_model(self, java_model): return CountVectorizerModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index bcbeacbe80..0b0ad2377f 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -406,6 +406,22 @@ class FeatureTests(PySparkTestCase): transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, ["a"]) + def test_count_vectorizer_with_binary(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) + cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") + model = cv.fit(dataset) + + transformedList = model.transform(dataset).select("features", "expected").collect() + + for r in transformedList: + feature, expected = r + self.assertEqual(feature, expected) + class HasInducedError(Params): -- cgit v1.2.3 From bf65c87f706019d235d7093637341668a13b1be1 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Thu, 14 Apr 2016 12:44:59 -0700 Subject: [SPARK-14618][ML][DOC] Updated RegressionEvaluator.metricName param doc ## What changes were proposed in this pull request? In Spark 1.4, we negated some metrics from RegressionEvaluator since CrossValidator always maximized metrics. This was fixed in 1.5, but the docs were not updated. This PR updates the docs. ## How was this patch tested? no tests Author: Joseph K. Bradley Closes #12377 from jkbradley/regeval-doc. 
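Since the change here is documentation-only, a quick usage sketch may help make the documented behavior concrete. The snippet is not taken from this patch; it assumes a `sqlContext` as provided by `spark-shell`, and the toy data is made up:

```scala
// Since Spark 1.5 the evaluator reports the plain metric value (no negation);
// CrossValidator decides whether to minimize or maximize it via isLargerBetter.
import org.apache.spark.ml.evaluation.RegressionEvaluator

val predictions = sqlContext.createDataFrame(Seq(
  (3.0, 2.5), (1.0, 1.5), (4.0, 3.5)
)).toDF("label", "prediction")

val evaluator = new RegressionEvaluator().setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)   // a positive RMSE, not -RMSE
assert(rmse > 0.0)
assert(!evaluator.isLargerBetter)            // rmse/mse/mae are minimized, r2 is maximized
```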
--- .../org/apache/spark/ml/evaluation/RegressionEvaluator.scala | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 4134e2dbc5..ed04b67bcc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -39,11 +39,12 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui def this() = this(Identifiable.randomUID("regEval")) /** - * param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`) + * Param for metric name in evaluation. Supports: + * - `"rmse"` (default): root mean squared error + * - `"mse"`: mean squared error + * - `"r2"`: R^2^ metric + * - `"mae"`: mean absolute error * - * Because we will maximize evaluation value (ref: `CrossValidator`), - * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), - * we take and output the negative of this metric. * @group param */ @Since("1.4.0") -- cgit v1.2.3 From bc748b7b8f3b5aee28aff9ea078c216ca137a5b7 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 14 Apr 2016 21:53:32 +0200 Subject: [SPARK-14238][ML][MLLIB][PYSPARK] Add binary toggle Param to PySpark HashingTF in ML & MLlib ## What changes were proposed in this pull request? This fix tries to add binary toggle Param to PySpark HashingTF in ML & MLlib. If this toggle is set, then all non-zero counts will be set to 1. Note: This fix (SPARK-14238) is extended from SPARK-13963 where Scala implementation was done. ## How was this patch tested? This fix adds two tests to cover the code changes. One for HashingTF in PySpark's ML and one for HashingTF in PySpark's MLLib. Author: Yong Tang Closes #12079 from yongtang/SPARK-14238. --- python/pyspark/ml/feature.py | 24 ++++++++++++++++++++++-- python/pyspark/ml/tests.py | 19 +++++++++++++++++++ python/pyspark/mllib/feature.py | 13 ++++++++++++- python/pyspark/mllib/tests.py | 16 ++++++++++++++++ 4 files changed, 69 insertions(+), 3 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 0b0c573eea..809a513316 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -536,14 +536,19 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java .. versionadded:: 1.3.0 """ + binary = Param(Params._dummy(), "binary", "If True, all non zero counts are set to 1. " + + "This is useful for discrete probabilistic models that model binary events " + + "rather than integer counts. 
Default False.", + typeConverter=TypeConverters.toBoolean) + @keyword_only - def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None): + def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None): """ __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None) """ super(HashingTF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) - self._setDefault(numFeatures=1 << 18) + self._setDefault(numFeatures=1 << 18, binary=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @@ -557,6 +562,21 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java kwargs = self.setParams._input_kwargs return self._set(**kwargs) + @since("2.0.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + self._paramMap[self.binary] = value + return self + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + @inherit_doc class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 0b0ad2377f..86c0254a2b 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -847,6 +847,25 @@ class TrainingSummaryTest(PySparkTestCase): self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) +class HashingTFTest(PySparkTestCase): + + def test_apply_binary_term_freqs(self): + sqlContext = SQLContext(self.sc) + + df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + n = 100 + hashingTF = HashingTF() + hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) + output = hashingTF.transform(df) + features = output.select("features").first().features.toArray() + expected = Vectors.sparse(n, {(ord("a") % n): 1.0, + (ord("b") % n): 1.0, + (ord("c") % n): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(features[i])) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 6129353525..b3dd2f63a5 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -379,6 +379,17 @@ class HashingTF(object): """ def __init__(self, numFeatures=1 << 20): self.numFeatures = numFeatures + self.binary = False + + @since("2.0.0") + def setBinary(self, value): + """ + If True, term frequency vector will be binary such that non-zero + term counts will be set to 1 + (default: False) + """ + self.binary = value + return self @since('1.2.0') def indexOf(self, term): @@ -398,7 +409,7 @@ class HashingTF(object): freq = {} for term in document: i = self.indexOf(term) - freq[i] = freq.get(i, 0) + 1.0 + freq[i] = 1.0 if self.binary else freq.get(i, 0) + 1.0 return Vectors.sparse(self.numFeatures, freq.items()) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5f515b666c..ac55fbf798 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -58,6 +58,7 @@ from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import HashingTF from 
pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler, ElementwiseProduct @@ -1583,6 +1584,21 @@ class ALSTests(MLlibTestCase): self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r))) +class HashingTFTest(MLlibTestCase): + + def test_binary_term_freqs(self): + hashingTF = HashingTF(100).setBinary(True) + doc = "a a b c c c".split(" ") + n = hashingTF.numFeatures + output = hashingTF.transform(doc).toArray() + expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, + hashingTF.indexOf("b"): 1.0, + hashingTF.indexOf("c"): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(output[i])) + + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: -- cgit v1.2.3 From d7e124edfe2578ecdf8e816a4dda3ce430a09172 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 14 Apr 2016 13:34:29 -0700 Subject: [SPARK-14545][SQL] Improve `LikeSimplification` by adding `a%b` rule ## What changes were proposed in this pull request? Current `LikeSimplification` handles the following four rules. - 'a%' => expr.StartsWith("a") - '%b' => expr.EndsWith("b") - '%a%' => expr.Contains("a") - 'a' => EqualTo("a") This PR adds the following rule. - 'a%b' => expr.Length() >= 2 && expr.StartsWith("a") && expr.EndsWith("b") Here, 2 is statically calculated from "a".size + "b".size. **Before** ``` scala> sql("select a from (select explode(array('abc','adc')) a) T where a like 'a%c'").explain() == Physical Plan == WholeStageCodegen : +- Filter a#5 LIKE a%c : +- INPUT +- Generate explode([abc,adc]), false, false, [a#5] +- Scan OneRowRelation[] ``` **After** ``` scala> sql("select a from (select explode(array('abc','adc')) a) T where a like 'a%c'").explain() == Physical Plan == WholeStageCodegen : +- Filter ((length(a#5) >= 2) && (StartsWith(a#5, a) && EndsWith(a#5, c))) : +- INPUT +- Generate explode([abc,adc]), false, false, [a#5] +- Scan OneRowRelation[] ``` ## How was this patch tested? Pass the Jenkins tests (including new testcase). Author: Dongjoon Hyun Closes #12312 from dongjoon-hyun/SPARK-14545. --- .../spark/sql/catalyst/optimizer/Optimizer.scala | 28 +++++++++++++--------- .../optimizer/LikeSimplificationSuite.scala | 14 +++++++++++ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index aeb1842677..f5172b213a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -517,22 +517,28 @@ object LikeSimplification extends Rule[LogicalPlan] { // Cases like "something\%" are not optimized, but this does not affect correctness. 
private val startsWith = "([^_%]+)%".r private val endsWith = "%([^_%]+)".r + private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r private val contains = "%([^_%]+)%".r private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(utf, StringType)) => - utf.toString match { - case startsWith(pattern) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case endsWith(pattern) => - EndsWith(l, Literal(pattern)) - case contains(pattern) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case equalTo(pattern) => - EqualTo(l, Literal(pattern)) + case Like(input, Literal(pattern, StringType)) => + pattern.toString match { + case startsWith(prefix) if !prefix.endsWith("\\") => + StartsWith(input, Literal(prefix)) + case endsWith(postfix) => + EndsWith(input, Literal(postfix)) + // 'a%a' pattern is basically same with 'a%' && '%a'. + // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. + case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => + And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)), + And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) + case contains(infix) if !infix.endsWith("\\") => + Contains(input, Literal(infix)) + case equalTo(str) => + EqualTo(input, Literal(str)) case _ => - Like(l, Literal.create(utf, StringType)) + Like(input, Literal.create(pattern, StringType)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index 741bc113cf..fdde89d079 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -61,6 +61,20 @@ class LikeSimplificationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("simplify Like into startsWith and EndsWith") { + val originalQuery = + testRelation + .where(('a like "abc\\%def") || ('a like "abc%def")) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .where(('a like "abc\\%def") || + (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("simplify Like into Contains") { val originalQuery = testRelation -- cgit v1.2.3
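To make the role of the `Length` guard in the new `startsAndEndsWith` rule concrete, here is a plain-Scala sketch of the string-level semantics (illustrative only, not the Catalyst code above). Without the length check, `StartsWith` and `EndsWith` can overlap on short inputs, so the single character "a" would wrongly satisfy the pattern 'a%a':

```scala
// String-level reading of: Length(input) >= prefix.size + postfix.size &&
//                           StartsWith(input, prefix) && EndsWith(input, postfix)
def likeStartsAndEndsWith(input: String, prefix: String, postfix: String): Boolean =
  input.length >= prefix.length + postfix.length &&
    input.startsWith(prefix) && input.endsWith(postfix)

assert(likeStartsAndEndsWith("abc", "a", "c"))   // 'a%c' matches "abc"
assert(likeStartsAndEndsWith("ac", "a", "c"))    // '%' may match the empty string
assert(!likeStartsAndEndsWith("a", "a", "a"))    // "a" must not match 'a%a'
assert(!likeStartsAndEndsWith("adb", "a", "c"))  // wrong suffix
```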
    (none) A string of extra JVM options to pass to the driver. For instance, GC settings or other logging. + Note that it is illegal to set maximum heap size (-Xmx) settings with this option. Maximum heap + size settings can be set with spark.driver.memory in the cluster mode and through + the --driver-memory command line option in the client mode.
    Note: In client mode, this config must not be set through the SparkConf directly in your application, because the driver JVM has already started at that point. Instead, please set this through the --driver-java-options command line option or in - your default properties file.
    (none) A string of extra JVM options to pass to executors. For instance, GC settings or other logging. - Note that it is illegal to set Spark properties or heap size settings with this option. Spark - properties should be set using a SparkConf object or the spark-defaults.conf file used with the - spark-submit script. Heap size settings can be set with spark.executor.memory. + Note that it is illegal to set Spark properties or maximum heap size (-Xmx) settings with this + option. Spark properties should be set using a SparkConf object or the spark-defaults.conf file + used with the spark-submit script. Maximum heap size settings can be set with spark.executor.memory.
    (none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use spark.driver.extraJavaOptions instead. Note that it is illegal + to set maximum heap size (-Xmx) settings with this option. Maximum heap size settings can be set + with spark.yarn.am.memory
    {Utils.bytesToString(d.toLong)} @@ -359,13 +359,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Duration @@ -374,8 +374,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { @@ -385,8 +385,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { @@ -397,8 +397,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { @@ -412,8 +412,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). - val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - getSchedulerDelay(info, metrics.get, currentTime).toDouble + val schedulerDelays = validTasks.map { taskUIData: TaskUIData => + getSchedulerDelay(taskUIData.taskInfo, taskUIData.metrics.get, currentTime).toDouble } val schedulerDelayTitle = Scheduler DelayInput Size / RecordsOutput Size / Records @@ -461,11 +461,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { @@ -476,8 +476,8 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { @@ -488,25 +488,25 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { Shuffle Write Size / RecordsShuffle spill (memory)Shuffle spill (disk)