diff options
author | Xiangrui Meng <meng@databricks.com> | 2016-04-12 11:30:09 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-12 11:30:09 -0700 |
commit | 1995c2e6482bf4af5a4be087bfc156311c1bec19 (patch) | |
tree | b653cb6402190b73ffc69b97a620599f5072b4b0 /mllib/src/main/scala | |
parent | 7f024c47441a2f84fcc34a6021b976f036ea24c4 (diff) | |
download | spark-1995c2e6482bf4af5a4be087bfc156311c1bec19.tar.gz spark-1995c2e6482bf4af5a4be087bfc156311c1bec19.tar.bz2 spark-1995c2e6482bf4af5a4be087bfc156311c1bec19.zip |
[SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer
## What changes were proposed in this pull request?
Use a random table name instead of `__THIS__` in SQLTransformer, and add a test for `transformSchema`. The problems with using `__THIS__` are:
* It doesn't work under HiveContext (in Spark 1.6)
* Race conditions: concurrent calls register their temp tables under the same fixed name
## How was this patch tested?
* Manual test with HiveContext.
* Added a unit test for `transformSchema` to improve coverage.
cc: yhuai
Author: Xiangrui Meng <meng@databricks.com>
Closes #12330 from mengxr/SPARK-14563.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala | 10 |
1 file changed, 6 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 95fe942c6b..2002d15745 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - val outputDF = dataset.sqlContext.sql(realStatement) - outputDF + dataset.sqlContext.sql(realStatement) } @Since("1.6.0") @@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) - dummyDF.registerTempTable(tableIdentifier) - val outputSchema = sqlContext.sql($(statement)).schema + val tableName = Identifiable.randomUID(uid) + val realStatement = $(statement).replace(tableIdentifier, tableName) + dummyDF.registerTempTable(tableName) + val outputSchema = sqlContext.sql(realStatement).schema + sqlContext.dropTempTable(tableName) outputSchema } |