aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala10
1 files changed, 10 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
index 553e0b8702..e213e17d0d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/SQLTransformerSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.types.{LongType, StructField, StructType}
class SQLTransformerSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -49,4 +50,13 @@ class SQLTransformerSuite
.setStatement("select * from __THIS__")
testDefaultReadWrite(t)
}
+
+ test("transformSchema") {
+ val df = sqlContext.range(10)
+ val outputSchema = new SQLTransformer()
+ .setStatement("SELECT id + 1 AS id1 FROM __THIS__")
+ .transformSchema(df.schema)
+ val expected = StructType(Seq(StructField("id1", LongType, nullable = false)))
+ assert(outputSchema === expected)
+ }
}