author     Yuhao Yang <yuhao.yang@intel.com>      2016-06-27 12:27:39 -0700
committer  Xiangrui Meng <meng@databricks.com>    2016-06-27 12:27:39 -0700
commit     c17b1abff8f8c6d24cb0cf4ff4f8c14a780c64b0
tree       86dbee40360a5f2c6baddb94edb00ba94ef972f9 /mllib/src/test
parent     c48c8ebc0aad433aab7af9e2ddf544d253ab9fd7
[SPARK-16187][ML] Implement util method for ML Matrix conversion in scala/java
## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-16187

This adds conversion utils between old and new matrix columns in a DataFrame, so users can migrate their datasets and pipelines manually.

## How was this patch tested?

Java and Scala unit tests.

Author: Yuhao Yang <yuhao.yang@intel.com>

Closes #13888 from hhbyyh/matComp.
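For reference, a minimal usage sketch of the new helpers in spark-shell style (not part of the patch; the toy DataFrame, column names, and SparkSession setup are illustrative assumptions):

import org.apache.spark.mllib.linalg.Matrices
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("matrix-column-migration").getOrCreate()

// A DataFrame carrying an old-style mllib.linalg.Matrix column.
val df = spark.createDataFrame(Seq(
  (1.0, Matrices.dense(2, 1, Array(1.0, 2.0)))
)).toDF("label", "features")

// Convert every matrix column to the new ml.linalg types, or only the named ones.
val newDF  = MLUtils.convertMatrixColumnsToML(df)
val sameDF = MLUtils.convertMatrixColumnsToML(df, "features")

// Convert back when a migrated dataset still has to feed legacy mllib code.
val oldDF = MLUtils.convertMatrixColumnsFromML(newDF)

Both directions take an optional varargs list of column names; with no names given, every matrix column in the schema is converted, mirroring the existing vector-column helpers.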
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java  | 29
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala    | 56
2 files changed, 82 insertions(+), 3 deletions(-)
diff --git a/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java
index 2fa0bd2546..e271a0a77c 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/util/JavaMLUtilsSuite.java
@@ -17,18 +17,22 @@
package org.apache.spark.mllib.util;
+import java.util.Arrays;
import java.util.Collections;
import org.junit.Assert;
import org.junit.Test;
import org.apache.spark.SharedSparkSession;
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.linalg.*;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
public class JavaMLUtilsSuite extends SharedSparkSession {
@@ -46,4 +50,25 @@ public class JavaMLUtilsSuite extends SharedSparkSession {
Row old1 = MLUtils.convertVectorColumnsFromML(newDataset1).first();
Assert.assertEquals(RowFactory.create(1.0, x), old1);
}
+
+ @Test
+ public void testConvertMatrixColumnsToAndFromML() {
+ Matrix x = Matrices.dense(2, 1, new double[]{1.0, 2.0});
+ StructType schema = new StructType(new StructField[]{
+ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
+ new StructField("features", new MatrixUDT(), false, Metadata.empty())
+ });
+ Dataset<Row> dataset = spark.createDataFrame(
+ Arrays.asList(
+ RowFactory.create(1.0, x)),
+ schema);
+
+ Dataset<Row> newDataset1 = MLUtils.convertMatrixColumnsToML(dataset);
+ Row new1 = newDataset1.first();
+ Assert.assertEquals(RowFactory.create(1.0, x.asML()), new1);
+ Row new2 = MLUtils.convertMatrixColumnsToML(dataset, "features").first();
+ Assert.assertEquals(new1, new2);
+ Row old1 = MLUtils.convertMatrixColumnsFromML(newDataset1).first();
+ Assert.assertEquals(RowFactory.create(1.0, x), old1);
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 3801bd127a..6aa93c9076 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -26,7 +26,7 @@ import breeze.linalg.{squaredDistance => breezeSquaredDistance}
import com.google.common.io.Files
import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
import org.apache.spark.mllib.util.TestingUtils._
@@ -301,4 +301,58 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
convertVectorColumnsFromML(df, "p._2")
}
}
+
+ test("convertMatrixColumnsToML") {
+ val x = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0))
+ val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build()
+ val y = Matrices.dense(2, 1, Array(0.2, 1.3))
+ val z = Matrices.ones(1, 1)
+ val p = (5.0, z)
+ val w = Matrices.dense(1, 1, Array(4.5)).asML
+ val df = spark.createDataFrame(Seq(
+ (0, x, y, p, w)
+ )).toDF("id", "x", "y", "p", "w")
+ .withColumn("x", col("x"), metadata)
+ val newDF1 = convertMatrixColumnsToML(df)
+ assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.")
+ val new1 = newDF1.first()
+ assert(new1 === Row(0, x.asML, y.asML, Row(5.0, z), w))
+ val new2 = convertMatrixColumnsToML(df, "x", "y").first()
+ assert(new2 === new1)
+ val new3 = convertMatrixColumnsToML(df, "y", "w").first()
+ assert(new3 === Row(0, x, y.asML, Row(5.0, z), w))
+ intercept[IllegalArgumentException] {
+ convertMatrixColumnsToML(df, "p")
+ }
+ intercept[IllegalArgumentException] {
+ convertMatrixColumnsToML(df, "p._2")
+ }
+ }
+
+ test("convertMatrixColumnsFromML") {
+ val x = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0)).asML
+ val metadata = new MetadataBuilder().putLong("numFeatures", 2L).build()
+ val y = Matrices.dense(2, 1, Array(0.2, 1.3)).asML
+ val z = Matrices.ones(1, 1).asML
+ val p = (5.0, z)
+ val w = Matrices.dense(1, 1, Array(4.5))
+ val df = spark.createDataFrame(Seq(
+ (0, x, y, p, w)
+ )).toDF("id", "x", "y", "p", "w")
+ .withColumn("x", col("x"), metadata)
+ val newDF1 = convertMatrixColumnsFromML(df)
+ assert(newDF1.schema("x").metadata === metadata, "Metadata should be preserved.")
+ val new1 = newDF1.first()
+ assert(new1 === Row(0, Matrices.fromML(x), Matrices.fromML(y), Row(5.0, z), w))
+ val new2 = convertMatrixColumnsFromML(df, "x", "y").first()
+ assert(new2 === new1)
+ val new3 = convertMatrixColumnsFromML(df, "y", "w").first()
+ assert(new3 === Row(0, x, Matrices.fromML(y), Row(5.0, z), w))
+ intercept[IllegalArgumentException] {
+ convertMatrixColumnsFromML(df, "p")
+ }
+ intercept[IllegalArgumentException] {
+ convertMatrixColumnsFromML(df, "p._2")
+ }
+ }
}
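The Scala suite above also asserts that column metadata survives the conversion and that naming a non-matrix column fails fast. Continuing the sketch from the top of the page, a rough illustration of that behaviour (Column.as(name, metadata) is just one public way to attach metadata; the test itself uses a three-argument withColumn):

import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.MetadataBuilder

// Tag the matrix column with some metadata, then convert only that column.
val meta   = new MetadataBuilder().putLong("numFeatures", 2L).build()
val tagged = df.withColumn("features", col("features").as("features", meta))

val converted = MLUtils.convertMatrixColumnsToML(tagged, "features")
assert(converted.schema("features").metadata == meta)  // metadata is preserved

// Naming a column that is not a matrix column (e.g. the tuple column "p" in the
// suite above) throws IllegalArgumentException, which is what the
// intercept[IllegalArgumentException] blocks check.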