author    Shuo Xiang <shuoxiangpub@gmail.com>    2015-01-07 23:22:37 -0800
committer Xiangrui Meng <meng@databricks.com>   2015-01-07 23:22:37 -0800
commit    c66a976300734b52d943d4ff811fc269c1bff2de (patch)
tree      9470fb54113c09722af1247f5467fc078b97ba67 /mllib
parent    2b729d22500c682435ef7adde566551b45a3c6e3 (diff)
[SPARK-5116][MLlib] Add extractor for SparseVector and DenseVector
Add extractors for SparseVector and DenseVector in MLlib to save some code when pattern matching on Vectors. For example, previously we may use:

    vec match {
      case dv: DenseVector =>
        val values = dv.values
        ...
      case sv: SparseVector =>
        val indices = sv.indices
        val values = sv.values
        val size = sv.size
        ...
    }

With the extractors it becomes:

    vec match {
      case DenseVector(values) =>
        ...
      case SparseVector(size, indices, values) =>
        ...
    }

Author: Shuo Xiang <shuoxiangpub@gmail.com>

Closes #3919 from coderxiang/extractor and squashes the following commits:

359e8d5 [Shuo Xiang] merge master
ca5fc3e [Shuo Xiang] merge master
0b1e190 [Shuo Xiang] use extractor for vectors in RowMatrix.scala
e961805 [Shuo Xiang] use extractor for vectors in StandardScaler.scala
c2bbdaf [Shuo Xiang] use extractor for vectors in IDF.scala
8433922 [Shuo Xiang] use extractor for vectors in NaiveBayes.scala and Normalizer.scala
d83c7ca [Shuo Xiang] use extractor for vectors in Vectors.scala
5523dad [Shuo Xiang] Add extractor for SparseVector and DenseVector
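For reference, the extractors added in Vectors.scala (see the diff below) are just companion objects with unapply methods. The following self-contained sketch shows the idea outside of Spark: the simplified Vector trait, the stripped-down DenseVector/SparseVector classes, and the names ExtractorDemo and sumValues are illustrative stand-ins, not MLlib code; only the unapply signatures mirror the ones added in this patch.

    // Minimal sketch of the extractor pattern used in this patch.
    // Class bodies are simplified placeholders; only the unapply
    // signatures match the ones added to Vectors.scala.
    trait Vector
    class DenseVector(val values: Array[Double]) extends Vector
    class SparseVector(
        val size: Int,
        val indices: Array[Int],
        val values: Array[Double]) extends Vector

    object DenseVector {
      // Exposes the backing array so `case DenseVector(values) =>` works.
      def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values)
    }

    object SparseVector {
      // Exposes (size, indices, values) so
      // `case SparseVector(size, indices, values) =>` works.
      def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] =
        Some((sv.size, sv.indices, sv.values))
    }

    object ExtractorDemo extends App {
      // Pattern matching via the extractors instead of type tests + field access.
      def sumValues(vec: Vector): Double = vec match {
        case DenseVector(values) => values.sum
        case SparseVector(_, _, values) => values.sum
        case v => throw new IllegalArgumentException(
          "Do not support vector type " + v.getClass)
      }

      println(sumValues(new DenseVector(Array(1.0, 2.0, 3.0))))              // 6.0
      println(sumValues(new SparseVector(3, Array(0, 2), Array(1.0, 4.0))))  // 5.0
    }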
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala      |  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala                    | 26
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala             | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala         | 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala                 | 25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala   | 24
6 files changed, 57 insertions, 51 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 8c8e4a161a..a967df857b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -93,10 +93,10 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
def run(data: RDD[LabeledPoint]) = {
val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
val values = v match {
- case sv: SparseVector =>
- sv.values
- case dv: DenseVector =>
- dv.values
+ case SparseVector(size, indices, values) =>
+ values
+ case DenseVector(values) =>
+ values
}
if (!values.forall(_ >= 0.0)) {
throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index 19120e1e8a..3260f27513 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -86,20 +86,20 @@ private object IDF {
df = BDV.zeros(doc.size)
}
doc match {
- case sv: SparseVector =>
- val nnz = sv.indices.size
+ case SparseVector(size, indices, values) =>
+ val nnz = indices.size
var k = 0
while (k < nnz) {
- if (sv.values(k) > 0) {
- df(sv.indices(k)) += 1L
+ if (values(k) > 0) {
+ df(indices(k)) += 1L
}
k += 1
}
- case dv: DenseVector =>
- val n = dv.size
+ case DenseVector(values) =>
+ val n = values.size
var j = 0
while (j < n) {
- if (dv.values(j) > 0.0) {
+ if (values(j) > 0.0) {
df(j) += 1L
}
j += 1
@@ -207,20 +207,20 @@ private object IDFModel {
def transform(idf: Vector, v: Vector): Vector = {
val n = v.size
v match {
- case sv: SparseVector =>
- val nnz = sv.indices.size
+ case SparseVector(size, indices, values) =>
+ val nnz = indices.size
val newValues = new Array[Double](nnz)
var k = 0
while (k < nnz) {
- newValues(k) = sv.values(k) * idf(sv.indices(k))
+ newValues(k) = values(k) * idf(indices(k))
k += 1
}
- Vectors.sparse(n, sv.indices, newValues)
- case dv: DenseVector =>
+ Vectors.sparse(n, indices, newValues)
+ case DenseVector(values) =>
val newValues = new Array[Double](n)
var j = 0
while (j < n) {
- newValues(j) = dv.values(j) * idf(j)
+ newValues(j) = values(j) * idf(j)
j += 1
}
Vectors.dense(newValues)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index 1ced26a9b7..32848e039e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -52,8 +52,8 @@ class Normalizer(p: Double) extends VectorTransformer {
// However, for sparse vector, the `index` array will not be changed,
// so we can re-use it to save memory.
vector match {
- case dv: DenseVector =>
- val values = dv.values.clone()
+ case DenseVector(vs) =>
+ val values = vs.clone()
val size = values.size
var i = 0
while (i < size) {
@@ -61,15 +61,15 @@ class Normalizer(p: Double) extends VectorTransformer {
i += 1
}
Vectors.dense(values)
- case sv: SparseVector =>
- val values = sv.values.clone()
+ case SparseVector(size, ids, vs) =>
+ val values = vs.clone()
val nnz = values.size
var i = 0
while (i < nnz) {
values(i) /= norm
i += 1
}
- Vectors.sparse(sv.size, sv.indices, values)
+ Vectors.sparse(size, ids, values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
index 8c4c5db525..3c2091732f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala
@@ -105,8 +105,8 @@ class StandardScalerModel private[mllib] (
// This can be avoided by having a local reference of `shift`.
val localShift = shift
vector match {
- case dv: DenseVector =>
- val values = dv.values.clone()
+ case DenseVector(vs) =>
+ val values = vs.clone()
val size = values.size
if (withStd) {
// Having a local reference of `factor` to avoid overhead as the comment before.
@@ -130,8 +130,8 @@ class StandardScalerModel private[mllib] (
// Having a local reference of `factor` to avoid overhead as the comment before.
val localFactor = factor
vector match {
- case dv: DenseVector =>
- val values = dv.values.clone()
+ case DenseVector(vs) =>
+ val values = vs.clone()
val size = values.size
var i = 0
while(i < size) {
@@ -139,18 +139,17 @@ class StandardScalerModel private[mllib] (
i += 1
}
Vectors.dense(values)
- case sv: SparseVector =>
+ case SparseVector(size, indices, vs) =>
// For sparse vector, the `index` array inside sparse vector object will not be changed,
// so we can re-use it to save memory.
- val indices = sv.indices
- val values = sv.values.clone()
+ val values = vs.clone()
val nnz = values.size
var i = 0
while (i < nnz) {
values(i) *= localFactor(indices(i))
i += 1
}
- Vectors.sparse(sv.size, indices, values)
+ Vectors.sparse(size, indices, values)
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index d40f13342a..bf1faa25ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -108,16 +108,16 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def serialize(obj: Any): Row = {
val row = new GenericMutableRow(4)
obj match {
- case sv: SparseVector =>
+ case SparseVector(size, indices, values) =>
row.setByte(0, 0)
- row.setInt(1, sv.size)
- row.update(2, sv.indices.toSeq)
- row.update(3, sv.values.toSeq)
- case dv: DenseVector =>
+ row.setInt(1, size)
+ row.update(2, indices.toSeq)
+ row.update(3, values.toSeq)
+ case DenseVector(values) =>
row.setByte(0, 1)
row.setNullAt(1)
row.setNullAt(2)
- row.update(3, dv.values.toSeq)
+ row.update(3, values.toSeq)
}
row
}
@@ -271,8 +271,8 @@ object Vectors {
def norm(vector: Vector, p: Double): Double = {
require(p >= 1.0)
val values = vector match {
- case dv: DenseVector => dv.values
- case sv: SparseVector => sv.values
+ case DenseVector(vs) => vs
+ case SparseVector(n, ids, vs) => vs
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
val size = values.size
@@ -427,6 +427,10 @@ class DenseVector(val values: Array[Double]) extends Vector {
}
}
+object DenseVector {
+ def unapply(dv: DenseVector): Option[Array[Double]] = Some(dv.values)
+}
+
/**
* A sparse vector represented by an index array and a value array.
*
@@ -474,3 +478,8 @@ class SparseVector(
}
}
}
+
+object SparseVector {
+ def unapply(sv: SparseVector): Option[(Int, Array[Int], Array[Double])] =
+ Some((sv.size, sv.indices, sv.values))
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index a3fca53929..fbd35e372f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -528,21 +528,21 @@ class RowMatrix(
iter.flatMap { row =>
val buf = new ListBuffer[((Int, Int), Double)]()
row match {
- case sv: SparseVector =>
- val nnz = sv.indices.size
+ case SparseVector(size, indices, values) =>
+ val nnz = indices.size
var k = 0
while (k < nnz) {
- scaled(k) = sv.values(k) / q(sv.indices(k))
+ scaled(k) = values(k) / q(indices(k))
k += 1
}
k = 0
while (k < nnz) {
- val i = sv.indices(k)
+ val i = indices(k)
val iVal = scaled(k)
if (iVal != 0 && rand.nextDouble() < p(i)) {
var l = k + 1
while (l < nnz) {
- val j = sv.indices(l)
+ val j = indices(l)
val jVal = scaled(l)
if (jVal != 0 && rand.nextDouble() < p(j)) {
buf += (((i, j), iVal * jVal))
@@ -552,11 +552,11 @@ class RowMatrix(
}
k += 1
}
- case dv: DenseVector =>
- val n = dv.values.size
+ case DenseVector(values) =>
+ val n = values.size
var i = 0
while (i < n) {
- scaled(i) = dv.values(i) / q(i)
+ scaled(i) = values(i) / q(i)
i += 1
}
i = 0
@@ -620,11 +620,9 @@ object RowMatrix {
// TODO: Find a better home (breeze?) for this method.
val n = v.size
v match {
- case dv: DenseVector =>
- blas.dspr("U", n, alpha, dv.values, 1, U)
- case sv: SparseVector =>
- val indices = sv.indices
- val values = sv.values
+ case DenseVector(values) =>
+ blas.dspr("U", n, alpha, values, 1, U)
+ case SparseVector(size, indices, values) =>
val nnz = indices.length
var colStartIdx = 0
var prevCol = 0