diff options
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala | 10 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala | 10 |
2 files changed, 16 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index 95fe942c6b..2002d15745 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -68,8 +68,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val tableName = Identifiable.randomUID(uid) dataset.registerTempTable(tableName) val realStatement = $(statement).replace(tableIdentifier, tableName) - val outputDF = dataset.sqlContext.sql(realStatement) - outputDF + dataset.sqlContext.sql(realStatement) } @Since("1.6.0") @@ -78,8 +77,11 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) val dummyDF = sqlContext.createDataFrame(dummyRDD, schema) - dummyDF.registerTempTable(tableIdentifier) - val outputSchema = sqlContext.sql($(statement)).schema + val tableName = Identifiable.randomUID(uid) + val realStatement = $(statement).replace(tableIdentifier, tableName) + dummyDF.registerTempTable(tableName) + val outputSchema = sqlContext.sql(realStatement).schema + sqlContext.dropTempTable(tableName) outputSchema } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala index 553e0b8702..e213e17d0d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.types.{LongType, StructField, StructType} class SQLTransformerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -49,4 +50,13 @@ class SQLTransformerSuite .setStatement("select * from __THIS__") testDefaultReadWrite(t) } + + test("transformSchema") { + val df = sqlContext.range(10) + val outputSchema = new SQLTransformer() + .setStatement("SELECT id + 1 AS id1 FROM __THIS__") + .transformSchema(df.schema) + val expected = StructType(Seq(StructField("id1", LongType, nullable = false))) + assert(outputSchema === expected) + } } |