diff options
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index 786c11c412..b87605ebfd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -68,16 +68,17 @@ public class JavaVectorSlicerSuite { RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) ); - DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + Dataset<Row> dataset = + jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); - DataFrame output = vectorSlicer.transform(dataset); + Dataset<Row> output = vectorSlicer.transform(dataset); - for (Row r : output.select("userFeatures", "features").take(2)) { + for (Row r : output.select("userFeatures", "features").takeRows(2)) { Vector features = r.getAs(1); Assert.assertEquals(features.size(), 2); } |