aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala88
1 files changed, 67 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index af0cfe22ca..34833e90d4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -52,7 +52,7 @@ sealed trait Vector extends Serializable {
override def equals(other: Any): Boolean = {
other match {
- case v2: Vector => {
+ case v2: Vector =>
if (this.size != v2.size) return false
(this, v2) match {
case (s1: SparseVector, s2: SparseVector) =>
@@ -63,20 +63,28 @@ sealed trait Vector extends Serializable {
Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values)
case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
}
- }
case _ => false
}
}
+ /**
+ * Returns a hash code value for the vector. The hash code is based on its size and its nonzeros
+ * in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]].
+ */
override def hashCode(): Int = {
- var result: Int = size + 31
- this.foreachActive { case (index, value) =>
- // ignore explict 0 for comparison between sparse and dense
- if (value != 0) {
- result = 31 * result + index
- // refer to {@link java.util.Arrays.equals} for hash algorithm
- val bits = java.lang.Double.doubleToLongBits(value)
- result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ // This is a reference implementation. It calls return in foreachActive, which is slow.
+ // Subclasses should override it with optimized implementation.
+ var result: Int = 31 + size
+ this.foreachActive { (index, value) =>
+ if (index < 16) {
+ // ignore explicit 0 for comparison between sparse and dense
+ if (value != 0) {
+ result = 31 * result + index
+ val bits = java.lang.Double.doubleToLongBits(value)
+ result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ }
+ } else {
+ return result
}
}
result
@@ -317,7 +325,7 @@ object Vectors {
case SparseVector(n, ids, vs) => vs
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
- val size = values.size
+ val size = values.length
if (p == 1) {
var sum = 0.0
@@ -371,8 +379,8 @@ object Vectors {
val v1Indices = v1.indices
val v2Values = v2.values
val v2Indices = v2.indices
- val nnzv1 = v1Indices.size
- val nnzv2 = v2Indices.size
+ val nnzv1 = v1Indices.length
+ val nnzv2 = v2Indices.length
var kv1 = 0
var kv2 = 0
@@ -401,7 +409,7 @@ object Vectors {
case (DenseVector(vv1), DenseVector(vv2)) =>
var kv = 0
- val sz = vv1.size
+ val sz = vv1.length
while (kv < sz) {
val score = vv1(kv) - vv2(kv)
squaredDistance += score * score
@@ -422,7 +430,7 @@ object Vectors {
var kv2 = 0
val indices = v1.indices
var squaredDistance = 0.0
- val nnzv1 = indices.size
+ val nnzv1 = indices.length
val nnzv2 = v2.size
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
@@ -451,8 +459,8 @@ object Vectors {
v1Values: Array[Double],
v2Indices: IndexedSeq[Int],
v2Values: Array[Double]): Boolean = {
- val v1Size = v1Values.size
- val v2Size = v2Values.size
+ val v1Size = v1Values.length
+ val v2Size = v2Values.length
var k1 = 0
var k2 = 0
var allEqual = true
@@ -493,7 +501,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
- val localValuesSize = values.size
+ val localValuesSize = values.length
val localValues = values
while (i < localValuesSize) {
@@ -501,6 +509,22 @@ class DenseVector(val values: Array[Double]) extends Vector {
i += 1
}
}
+
+ override def hashCode(): Int = {
+ var result: Int = 31 + size
+ var i = 0
+ val end = math.min(values.length, 16)
+ while (i < end) {
+ val v = values(i)
+ if (v != 0.0) {
+ result = 31 * result + i
+ val bits = java.lang.Double.doubleToLongBits(values(i))
+ result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ }
+ i += 1
+ }
+ result
+ }
}
object DenseVector {
@@ -522,8 +546,8 @@ class SparseVector(
val values: Array[Double]) extends Vector {
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
- s" indices match the dimension of the values. You provided ${indices.size} indices and " +
- s" ${values.size} values.")
+ s" indices match the dimension of the values. You provided ${indices.length} indices and " +
+ s" ${values.length} values.")
override def toString: String =
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
@@ -547,7 +571,7 @@ class SparseVector(
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
- val localValuesSize = values.size
+ val localValuesSize = values.length
val localIndices = indices
val localValues = values
@@ -556,6 +580,28 @@ class SparseVector(
i += 1
}
}
+
+ override def hashCode(): Int = {
+ var result: Int = 31 + size
+ val end = values.length
+ var continue = true
+ var k = 0
+ while ((k < end) & continue) {
+ val i = indices(k)
+ if (i < 16) {
+ val v = values(k)
+ if (v != 0.0) {
+ result = 31 * result + i
+ val bits = java.lang.Double.doubleToLongBits(v)
+ result = 31 * result + (bits ^ (bits >>> 32)).toInt
+ }
+ } else {
+ continue = false
+ }
+ k += 1
+ }
+ result
+ }
}
object SparseVector {