aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2016-04-12 11:30:09 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-12 11:30:09 -0700
commit1995c2e6482bf4af5a4be087bfc156311c1bec19 (patch)
treeb653cb6402190b73ffc69b97a620599f5072b4b0 /mllib/src
parent7f024c47441a2f84fcc34a6021b976f036ea24c4 (diff)
downloadspark-1995c2e6482bf4af5a4be087bfc156311c1bec19.tar.gz
spark-1995c2e6482bf4af5a4be087bfc156311c1bec19.tar.bz2
spark-1995c2e6482bf4af5a4be087bfc156311c1bec19.zip
[SPARK-14563][ML] use a random table name instead of __THIS__ in SQLTransformer
## What changes were proposed in this pull request? Use a random table name instead of `__THIS__` in SQLTransformer, and add a test for `transformSchema`. The problems of using `__THIS__` are: * It doesn't work under HiveContext (in Spark 1.6) * Race conditions ## How was this patch tested? * Manual test with HiveContext. * Added a unit test for `transformSchema` to improve coverage. cc: yhuai Author: Xiangrui Meng <meng@databricks.com> Closes #12330 from mengxr/SPARK-14563.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala10
2 files changed, 16 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 95fe942c6b..2002d15745 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val tableName = Identifiable.randomUID(uid)
dataset.registerTempTable(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
- val outputDF = dataset.sqlContext.sql(realStatement)
- outputDF
+ dataset.sqlContext.sql(realStatement)
}
@Since("1.6.0")
@@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
- dummyDF.registerTempTable(tableIdentifier)
- val outputSchema = sqlContext.sql($(statement)).schema
+ val tableName = Identifiable.randomUID(uid)
+ val realStatement = $(statement).replace(tableIdentifier, tableName)
+ dummyDF.registerTempTable(tableName)
+ val outputSchema = sqlContext.sql(realStatement).schema
+ sqlContext.dropTempTable(tableName)
outputSchema
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 553e0b8702..e213e17d0d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
class SQLTransformerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -49,4 +50,13 @@ class SQLTransformerSuite
.setStatement("select * from __THIS__")
testDefaultReadWrite(t)
}
+
+ test("transformSchema") {
+ val df = sqlContext.range(10)
+ val outputSchema = new SQLTransformer()
+ .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+ .transformSchema(df.schema)
+ val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
+ assert(outputSchema === expected)
+ }
}