aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java9
1 files changed, 5 insertions, 4 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index 786c11c412..b87605ebfd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -31,7 +31,7 @@ import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -68,16 +68,17 @@ public class JavaVectorSlicerSuite {
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
);
- DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
+ Dataset<Row> dataset =
+ jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
- DataFrame output = vectorSlicer.transform(dataset);
+ Dataset<Row> output = vectorSlicer.transform(dataset);
- for (Row r : output.select("userFeatures", "features").take(2)) {
+ for (Row r : output.select("userFeatures", "features").takeRows(2)) {
Vector features = r.getAs(1);
Assert.assertEquals(features.size(), 2);
}