aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-11-03 22:29:48 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-03 22:29:48 -0800
commit1a9c6cddadebdc53d083ac3e0da276ce979b5d1f (patch)
treeb485818ba52a9287ae7124e57ef55f1d974f3a1f /mllib
parent04450d11548cfb25d4fb77d4a33e3a7cd4254183 (diff)
downloadspark-1a9c6cddadebdc53d083ac3e0da276ce979b5d1f.tar.gz
spark-1a9c6cddadebdc53d083ac3e0da276ce979b5d1f.tar.bz2
spark-1a9c6cddadebdc53d083ac3e0da276ce979b5d1f.zip
[SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD
Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and Python. With this PR, we can easily map a RDD[LabeledPoint] to a SchemaRDD, and then select columns or save to a Parquet file. Examples in Scala/Python are attached. The Scala code was copied from jkbradley. ~~This PR contains the changes from #3068 . I will rebase after #3068 is merged.~~ marmbrus jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #3070 from mengxr/SPARK-3573 and squashes the following commits: 3a0b6e5 [Xiangrui Meng] organize imports 236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples
Diffstat (limited to 'mllib')
-rw-r--r--mllib/pom.xml5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala69
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala11
3 files changed, 83 insertions, 2 deletions
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fb7239e779..87a7ddaba9 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -46,6 +46,11 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
</dependency>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6af225b7f4..ac217edc61 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -17,22 +17,26 @@
package org.apache.spark.mllib.linalg
-import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import java.util
+import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import scala.annotation.varargs
import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
-import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
+import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.catalyst.types._
/**
* Represents a numeric vector, whose index type is Int and value type is Double.
*
* Note: Users should not implement this interface.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
sealed trait Vector extends Serializable {
/**
@@ -75,6 +79,65 @@ sealed trait Vector extends Serializable {
}
/**
+ * User-defined type for [[Vector]] which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.SchemaRDD]].
+ */
+private[spark] class VectorUDT extends UserDefinedType[Vector] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+ // vectors. The "values" field is nullable because we might want to add binary vectors later,
+ // which uses "size" and "indices", but not "values".
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("size", IntegerType, nullable = true),
+ StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(4)
+ obj match {
+ case sv: SparseVector =>
+ row.setByte(0, 0)
+ row.setInt(1, sv.size)
+ row.update(2, sv.indices.toSeq)
+ row.update(3, sv.values.toSeq)
+ case dv: DenseVector =>
+ row.setByte(0, 1)
+ row.setNullAt(1)
+ row.setNullAt(2)
+ row.update(3, dv.values.toSeq)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Vector = {
+ datum match {
+ case row: Row =>
+ require(row.length == 4,
+ s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
+ val tpe = row.getByte(0)
+ tpe match {
+ case 0 =>
+ val size = row.getInt(1)
+ val indices = row.getAs[Iterable[Int]](2).toArray
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new SparseVector(size, indices, values)
+ case 1 =>
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new DenseVector(values)
+ }
+ }
+ }
+
+ override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
+
+ override def userClass: Class[Vector] = classOf[Vector]
+}
+
+/**
* Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
* We don't use the name `Vector` because Scala imports
* [[scala.collection.immutable.Vector]] by default.
@@ -191,6 +254,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {
override def size: Int = values.length
@@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
* @param indices index array, assume to be strictly increasing.
* @param values value array, must have the same length as the index array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class SparseVector(
override val size: Int,
val indices: Array[Int],
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index cd651fe2d2..93a84fe07b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite {
throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.")
}
}
+
+ test("VectorUDT") {
+ val dv0 = Vectors.dense(Array.empty[Double])
+ val dv1 = Vectors.dense(1.0, 2.0)
+ val sv0 = Vectors.sparse(2, Array.empty, Array.empty)
+ val sv1 = Vectors.sparse(2, Array(1), Array(2.0))
+ val udt = new VectorUDT()
+ for (v <- Seq(dv0, dv1, sv0, sv1)) {
+ assert(v === udt.deserialize(udt.serialize(v)))
+ }
+ }
}