aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala10
1 files changed, 6 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 95fe942c6b..2002d15745 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val tableName = Identifiable.randomUID(uid)
dataset.registerTempTable(tableName)
val realStatement = $(statement).replace(tableIdentifier, tableName)
- val outputDF = dataset.sqlContext.sql(realStatement)
- outputDF
+ dataset.sqlContext.sql(realStatement)
}
@Since("1.6.0")
@@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
val dummyDF = sqlContext.createDataFrame(dummyRDD, schema)
- dummyDF.registerTempTable(tableIdentifier)
- val outputSchema = sqlContext.sql($(statement)).schema
+ val tableName = Identifiable.randomUID(uid)
+ val realStatement = $(statement).replace(tableIdentifier, tableName)
+ dummyDF.registerTempTable(tableName)
+ val outputSchema = sqlContext.sql(realStatement).schema
+ sqlContext.dropTempTable(tableName)
outputSchema
}