From d9efe63ecdc60a9955f1924de0e8a00bcb6a559d Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Sun, 21 Feb 2016 22:53:15 +0800
Subject: [SPARK-12799] Simplify various string output for expressions

This PR introduces several major changes:

1. Replacing `Expression.prettyString` with `Expression.sql`

   The `prettyString` method is mostly an internal, developer-facing facility for debugging
   purposes, and shouldn't be exposed to users.

2. Using SQL-like representation as column names for selected fields that are not named
   expressions (back-ticks and double quotes should be removed)

   Before, we were using `prettyString` as column names when possible, and sometimes the
   resulting column names can be weird. Here are several examples:

   Expression          | `prettyString` | `sql`     | Note
   ------------------- | -------------- | --------- | ----------------
   `a && b`            | `a && b`       | `a AND b` |
   `a.getField("f")`   | `a[f]`         | `a.f`     | `a` is a struct

3. Adding trait `NonSQLExpression` extending from `Expression` for expressions that don't have
   a SQL representation (e.g. Scala UDF/UDAF and Java/Scala object expressions used for
   encoders)

   `NonSQLExpression.sql` may return an arbitrary user-facing string representation of the
   expression.

Author: Cheng Lian

Closes #10757 from liancheng/spark-12799.simplify-expression-string-methods.
---
 .../org/apache/spark/sql/hive/SQLBuilder.scala     | 22 ++++++--
 .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 43 ++++++++++-----
 .../spark/sql/hive/ExpressionSQLBuilderSuite.scala | 15 +++--
 .../spark/sql/hive/LogicalPlanToSQLSuite.scala     | 64 ++++++++++++++++++++++
 .../spark/sql/hive/execution/HivePlanTest.scala    |  4 +-
 .../apache/spark/sql/hive/orc/OrcSourceSuite.scala | 22 ++++----
 .../org/apache/spark/sql/hive/parquetSuites.scala  |  5 +-
 7 files changed, 137 insertions(+), 38 deletions(-)
(limited to 'sql/hive')
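To make the column-naming change in item 2 concrete, here is a rough, illustrative Scala snippet written for this summary (it is not code from the patch, and the exact rendered name may differ slightly from the comments):

```scala
// Illustration only: what the new Expression.sql-based column naming looks like.
// Assumes a SQLContext named `sqlContext` with implicits imported (spark-shell style).
import sqlContext.implicits._

val df = Seq((true, false)).toDF("a", "b")

// An unnamed projection used to be labelled via prettyString (e.g. "a && b");
// with this patch the name comes from Expression.sql, i.e. the "a AND b" form
// shown in the table above (exact parenthesisation may vary).
val projected = df.select($"a" && $"b")
println(projected.columns.mkString(", "))
```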
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index bf5edb4759..1dda39d44e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -24,10 +24,11 @@ import scala.util.control.NonFatal
 import org.apache.spark.Logging
 import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, NonSQLExpression, SortOrder}
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.util.quoteIdentifier
 import org.apache.spark.sql.execution.datasources.LogicalRelation
 
 /**
@@ -37,11 +38,21 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
  * supported by this builder (yet).
  */
 class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging {
+  require(logicalPlan.resolved, "SQLBuilder only supports resolved logical query plans")
+
   def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext)
 
   def toSQL: String = {
     val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
     try {
+      canonicalizedPlan.transformAllExpressions {
+        case e: NonSQLExpression =>
+          throw new UnsupportedOperationException(
+            s"Expression $e doesn't have a SQL representation"
+          )
+        case e => e
+      }
+
       val generatedSQL = toSQL(canonicalizedPlan)
       logDebug(
         s"""Built SQL query string successfully from given logical plan:
@@ -95,7 +106,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       p.child match {
         // Persisted data source relation
         case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) =>
-          s"`$database`.`$table`"
+          s"${quoteIdentifier(database)}.${quoteIdentifier(table)}"
         // Parentheses is not used for persisted data source relations
         // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1
         case Subquery(_, _: LogicalRelation | _: MetastoreRelation) =>
@@ -114,8 +125,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
 
     case p: MetastoreRelation =>
       build(
-        s"`${p.databaseName}`.`${p.tableName}`",
-        p.alias.map(a => s" AS `$a`").getOrElse("")
+        s"${quoteIdentifier(p.databaseName)}.${quoteIdentifier(p.tableName)}",
+        p.alias.map(a => s" AS ${quoteIdentifier(a)}").getOrElse("")
       )
 
     case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
@@ -148,7 +159,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
    * The segments are trimmed so only a single space appears in the separation.
    * For example, `build("a", " b ", " c")` becomes "a b c".
    */
-  private def build(segments: String*): String = segments.map(_.trim).mkString(" ")
+  private def build(segments: String*): String =
+    segments.map(_.trim).filter(_.nonEmpty).mkString(" ")
 
   private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
     build(
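A usage sketch of the new guard, modelled on the regression test added later in this patch (the `double_it` UDF name and the table `t0` are assumptions of the sketch, not code from the commit):

```scala
// Sketch: a plan containing a Scala UDF (a NonSQLExpression) has no SQL form,
// so SQLBuilder now rejects it up front instead of emitting a broken query string.
sqlContext.udf.register("double_it", (_: Int) * 2)      // hypothetical UDF name

val df = sqlContext.sql("SELECT double_it(id) FROM t0") // assumes a table t0 exists

try {
  new SQLBuilder(df).toSQL
} catch {
  case e: UnsupportedOperationException =>
    println(s"No SQL representation: ${e.getMessage}")
}
```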
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index d5ed838ca4..bcafa045e0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.sequenceOption
 import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.hive.client.HiveClientImpl
 import org.apache.spark.sql.types._
@@ -70,18 +69,28 @@ private[hive] class HiveFunctionRegistry(
     // catch the exception and throw AnalysisException instead.
     try {
       if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveGenericUDF(
+        val udf = HiveGenericUDF(
           name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children)
+        udf.dataType // Force it to check input data types.
+        udf
       } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
+        val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children)
+        udf.dataType // Force it to check input data types.
+        udf
       } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
+        val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children)
+        udf.dataType // Force it to check input data types.
+        udf
       } else if (
         classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
+        val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children)
+        udaf.dataType // Force it to check input data types.
+        udaf
       } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
-        HiveUDAFFunction(
+        val udaf = HiveUDAFFunction(
           name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true)
+        udaf.dataType // Force it to check input data types.
+        udaf
       } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
         val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children)
         udtf.elementTypes // Force it to check input data types.
@@ -163,7 +172,7 @@ private[hive] case class HiveSimpleUDF(
   @transient
   private lazy val conversionHelper = new ConversionHelper(method, arguments)
 
-  override val dataType = javaClassToDataType(method.getReturnType)
+  override lazy val dataType = javaClassToDataType(method.getReturnType)
 
   @transient
   lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector(
@@ -189,6 +198,8 @@ private[hive] case class HiveSimpleUDF(
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }
 
+  override def prettyName: String = name
+
   override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
 }
 
@@ -233,11 +244,11 @@ private[hive] case class HiveGenericUDF(
   }
 
   @transient
-  private lazy val deferedObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
+  private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
     new DeferredObjectAdapter(inspect, child.dataType)
   }.toArray[DeferredObject]
 
-  override val dataType: DataType = inspectorToDataType(returnInspector)
+  override lazy val dataType: DataType = inspectorToDataType(returnInspector)
 
   override def eval(input: InternalRow): Any = {
     returnInspector // Make sure initialized.
@@ -245,20 +256,20 @@ private[hive] case class HiveGenericUDF(
     var i = 0
     while (i < children.length) {
       val idx = i
-      deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set(
+      deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set(
        () => {
          children(idx).eval(input)
        })
       i += 1
     }
-    unwrap(function.evaluate(deferedObjects), returnInspector)
+    unwrap(function.evaluate(deferredObjects), returnInspector)
   }
 
+  override def prettyName: String = name
+
   override def toString: String = {
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }
-
-  override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
 }
 
 /**
@@ -340,7 +351,7 @@ private[hive] case class HiveGenericUDTF(
     s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
   }
 
-  override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})"
+  override def prettyName: String = name
 }
 
 /**
@@ -432,7 +443,9 @@ private[hive] case class HiveUDAFFunction(
 
   override def supportsPartial: Boolean = false
 
-  override val dataType: DataType = inspectorToDataType(returnInspector)
+  override lazy val dataType: DataType = inspectorToDataType(returnInspector)
+
+  override def prettyName: String = name
 
   override def sql(isDistinct: Boolean): String = {
     val distinct = if (isDistinct) "DISTINCT " else " "
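The registry changes above appear to follow one pattern: `dataType` becomes a `lazy val` so that merely constructing the UDF wrapper stays cheap, and the registry then forces it once so that bad argument types surface immediately as an analysis-time error. A minimal, hypothetical sketch of that pattern in plain Scala (the names are invented for illustration, not the actual Hive wrapper classes):

```scala
// Minimal sketch of the "force the lazy dataType" pattern used above.
// HypotheticalUDF stands in for HiveSimpleUDF / HiveGenericUDF / HiveUDAFFunction.
case class HypotheticalUDF(name: String, argTypes: Seq[String]) {
  // Computed lazily so that constructing the object cannot throw;
  // the validating work happens only when dataType is first accessed.
  lazy val dataType: String = {
    require(argTypes.nonEmpty, s"$name expects at least one argument")
    "int"
  }
}

def lookupFunction(name: String, argTypes: Seq[String]): HypotheticalUDF = {
  val udf = HypotheticalUDF(name, argTypes)
  udf.dataType // Force it to check input data types, as the registry now does.
  udf
}
```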
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
index 3fb6543b1a..e4b4d1861a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala
@@ -26,17 +26,24 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest {
   test("literal") {
     checkSQL(Literal("foo"), "\"foo\"")
     checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"")
-    checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)")
-    checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)")
+    checkSQL(Literal(1: Byte), "1Y")
+    checkSQL(Literal(2: Short), "2S")
     checkSQL(Literal(4: Int), "4")
-    checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)")
+    checkSQL(Literal(8: Long), "8L")
     checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)")
-    checkSQL(Literal(2.5D), "2.5")
+    checkSQL(Literal(2.5D), "2.5D")
     checkSQL(
       Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')")
     // TODO tests for decimals
   }
 
+  test("attributes") {
+    checkSQL('a.int, "`a`")
+    checkSQL(Symbol("foo bar").int, "`foo bar`")
+    // Keyword
+    checkSQL('int.int, "`int`")
+  }
+
   test("binary comparisons") {
     checkSQL('a.int === 'b.int, "(`a` = `b`)")
     checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index dc8ac7e47f..5255b150aa 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -29,6 +29,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     sql("DROP TABLE IF EXISTS t0")
     sql("DROP TABLE IF EXISTS t1")
     sql("DROP TABLE IF EXISTS t2")
+
     sqlContext.range(10).write.saveAsTable("t0")
 
     sqlContext
@@ -168,4 +169,67 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
       }
     }
   }
+
+  test("plans with non-SQL expressions") {
+    sqlContext.udf.register("foo", (_: Int) * 2)
+    intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL)
+  }
+
+  test("named expression in column names shouldn't be quoted") {
+    def checkColumnNames(query: String, expectedColNames: String*): Unit = {
+      checkHiveQl(query)
+      assert(sql(query).columns === expectedColNames)
+    }
+
+    // Attributes
+    checkColumnNames(
+      """SELECT * FROM (
+        |  SELECT 1 AS a, 2 AS b, 3 AS `we``ird`
+        |) s
+      """.stripMargin,
+      "a", "b", "we`ird"
+    )
+
+    checkColumnNames(
+      """SELECT x.a, y.a, x.b, y.b
+        |FROM (SELECT 1 AS a, 2 AS b) x
+        |INNER JOIN (SELECT 1 AS a, 2 AS b) y
+        |ON x.a = y.a
+      """.stripMargin,
+      "a", "a", "b", "b"
+    )
+
+    // String literal
+    checkColumnNames(
+      "SELECT 'foo', '\"bar\\''",
+      "foo", "\"bar\'"
+    )
+
+    // Numeric literals (should have CAST or suffixes in column names)
+    checkColumnNames(
+      "SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D",
+      "1", "2", "3", "4", "5.1", "6.1"
+    )
+
+    // Aliases
+    checkColumnNames(
+      "SELECT 1 AS a",
+      "a"
+    )
+
+    // Complex type extractors
+    checkColumnNames(
+      """SELECT
+        |  a.f1, b[0].f1, b.f1, c["foo"], d[0]
+        |FROM (
+        |  SELECT
+        |    NAMED_STRUCT("f1", 1, "f2", "foo") AS a,
+        |    ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b,
+        |    MAP("foo", 1) AS c,
+        |    ARRAY(1) AS d
+        |) s
+      """.stripMargin,
+      "f1", "b[0].f1", "f1", "c[foo]", "d[0]"
+    )
+  }
 }
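The typed literal suffixes exercised above (`1Y`, `2S`, `4L`, `6.1D`) are what `Expression.sql` now emits for byte, short, long and double literals; per the new test, the suffixes are stripped from result column names while the literal types themselves are unchanged. A quick spark-shell style illustration, assuming a `sqlContext` is in scope (output shapes approximate):

```scala
// Typed literal suffixes: Y = TINYINT, S = SMALLINT, L = BIGINT, D = DOUBLE.
val df = sqlContext.sql("SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D")

// Column names drop the suffixes (asserted by the test above)...
println(df.columns.mkString(", "))   // 1, 2, 3, 4, 5.1, 6.1

// ...while the schema keeps the literal types.
df.printSchema()
```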
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
index cd055f9eca..d8d3448add 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HivePlanTest.scala
@@ -29,8 +29,8 @@ class HivePlanTest extends QueryTest with TestHiveSingleton {
 
   test("udf constant folding") {
     Seq.empty[Tuple1[Int]].toDF("a").registerTempTable("t")
-    val optimized = sql("SELECT cos(null) FROM t").queryExecution.optimizedPlan
-    val correctAnswer = sql("SELECT cast(null as double) FROM t").queryExecution.optimizedPlan
+    val optimized = sql("SELECT cos(null) AS c FROM t").queryExecution.optimizedPlan
+    val correctAnswer = sql("SELECT cast(null as double) AS c FROM t").queryExecution.optimizedPlan
     comparePlans(optimized, correctAnswer)
   }
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 27ea3e8041..fe446774ef 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -131,17 +131,17 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA
     val df = sql(
       """
         |SELECT
-        |  CAST(null as TINYINT),
-        |  CAST(null as SMALLINT),
-        |  CAST(null as INT),
-        |  CAST(null as BIGINT),
-        |  CAST(null as FLOAT),
-        |  CAST(null as DOUBLE),
-        |  CAST(null as DECIMAL(7,2)),
-        |  CAST(null as TIMESTAMP),
-        |  CAST(null as DATE),
-        |  CAST(null as STRING),
-        |  CAST(null as VARCHAR(10))
+        |  CAST(null as TINYINT) as c0,
+        |  CAST(null as SMALLINT) as c1,
+        |  CAST(null as INT) as c2,
+        |  CAST(null as BIGINT) as c3,
+        |  CAST(null as FLOAT) as c4,
+        |  CAST(null as DOUBLE) as c5,
+        |  CAST(null as DECIMAL(7,2)) as c6,
+        |  CAST(null as TIMESTAMP) as c7,
+        |  CAST(null as DATE) as c8,
+        |  CAST(null as STRING) as c9,
+        |  CAST(null as VARCHAR(10)) as c10
         |FROM orc_temp_table limit 1
       """.stripMargin)
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index c997453803..a6ca7d0386 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -626,7 +626,10 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
       sql(
         s"""CREATE TABLE array_of_struct
            |STORED AS PARQUET LOCATION '$path'
-           |AS SELECT '1st', '2nd', ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b'))
+           |AS SELECT
+           |  '1st' AS a,
+           |  '2nd' AS b,
+           |  ARRAY(NAMED_STRUCT('a', 'val_a', 'b', 'val_b')) AS c
          """.stripMargin)
 
       checkAnswer(
-- 
cgit v1.2.3