about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Imran Rashid <imran@quantifind.com> 2012-07-12 12:52:12 -0700
committer: Matei Zaharia <matei@eecs.berkeley.edu> 2012-07-28 20:15:51 -0700
commit: 2d666b9d76ab5bcb9511b2b3994ad50b8774072b (patch)
tree: 331929ec381cab17e66203e99e319f3f5352a319
parent: edc6972f8e14e78a243040f8c4e252884b63c55d (diff)
download: spark-2d666b9d76ab5bcb9511b2b3994ad50b8774072b.tar.gz
spark-2d666b9d76ab5bcb9511b2b3994ad50b8774072b.tar.bz2
spark-2d666b9d76ab5bcb9511b2b3994ad50b8774072b.zip
add some functionality to Vector, delete copy in AccumulatorSuite
-rw-r--r-- core/src/main/scala/spark/util/Vector.scala 32
-rw-r--r-- core/src/test/scala/spark/AccumulatorSuite.scala 114
2 files changed, 33 insertions, 113 deletions
diff --git a/core/src/main/scala/spark/util/Vector.scala b/core/src/main/scala/spark/util/Vector.scala
index e5604687e9..4e95ac2ac6 100644
--- a/core/src/main/scala/spark/util/Vector.scala
+++ b/core/src/main/scala/spark/util/Vector.scala
@@ -29,7 +29,37 @@ class Vector(val elements: Array[Double]) extends Serializable {
return ans
}
- def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
+ /**
+ * return (this + plus) dot other, but without creating any intermediate storage
+ * @param plus
+ * @param other
+ * @return
+ */
+ def plusDot(plus: Vector, other: Vector): Double = {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ if (length != plus.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ var ans = 0.0
+ var i = 0
+ while (i < length) {
+ ans += (this(i) + plus(i)) * other(i)
+ i += 1
+ }
+ return ans
+ }
+
+ def +=(other: Vector) {
+ if (length != other.length)
+ throw new IllegalArgumentException("Vectors of different length")
+ var ans = 0.0
+ var i = 0
+ while (i < length) {
+ elements(i) += other(i)
+ i += 1
+ }
+ }
+
def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
def / (d: Double): Vector = this * (1 / d)
diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala
index 2297ecf50d..24c4591034 100644
--- a/core/src/test/scala/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/spark/AccumulatorSuite.scala
@@ -64,8 +64,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
test ("value readable in tasks") {
- import Vector.VectorAccumParam._
- import Vector._
+ import spark.util.Vector
//stochastic gradient descent with weights stored in accumulator -- should be able to read value as we go
//really easy data
@@ -121,113 +120,4 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers {
}
}
-}
-
-
-
-//ugly copy and paste from examples ...
-class Vector(val elements: Array[Double]) extends Serializable {
- def length = elements.length
-
- def apply(index: Int) = elements(index)
-
- def + (other: Vector): Vector = {
- if (length != other.length)
- throw new IllegalArgumentException("Vectors of different length")
- return Vector(length, i => this(i) + other(i))
- }
-
- def - (other: Vector): Vector = {
- if (length != other.length)
- throw new IllegalArgumentException("Vectors of different length")
- return Vector(length, i => this(i) - other(i))
- }
-
- def dot(other: Vector): Double = {
- if (length != other.length)
- throw new IllegalArgumentException("Vectors of different length")
- var ans = 0.0
- var i = 0
- while (i < length) {
- ans += this(i) * other(i)
- i += 1
- }
- return ans
- }
-
- def plusDot(plus: Vector, other: Vector): Double = {
- if (length != other.length)
- throw new IllegalArgumentException("Vectors of different length")
- if (length != plus.length)
- throw new IllegalArgumentException("Vectors of different length")
- var ans = 0.0
- var i = 0
- while (i < length) {
- ans += (this(i) + plus(i)) * other(i)
- i += 1
- }
- return ans
- }
-
- def += (other: Vector) {
- if (length != other.length)
- throw new IllegalArgumentException("Vectors of different length")
- var ans = 0.0
- var i = 0
- while (i < length) {
- elements(i) += other(i)
- i += 1
- }
- }
-
-
- def * (scale: Double): Vector = Vector(length, i => this(i) * scale)
-
- def / (d: Double): Vector = this * (1 / d)
-
- def unary_- = this * -1
-
- def sum = elements.reduceLeft(_ + _)
-
- def squaredDist(other: Vector): Double = {
- var ans = 0.0
- var i = 0
- while (i < length) {
- ans += (this(i) - other(i)) * (this(i) - other(i))
- i += 1
- }
- return ans
- }
-
- def dist(other: Vector): Double = math.sqrt(squaredDist(other))
-
- override def toString = elements.mkString("(", ", ", ")")
-}
-
-object Vector {
- def apply(elements: Array[Double]) = new Vector(elements)
-
- def apply(elements: Double*) = new Vector(elements.toArray)
-
- def apply(length: Int, initializer: Int => Double): Vector = {
- val elements = new Array[Double](length)
- for (i <- 0 until length)
- elements(i) = initializer(i)
- return new Vector(elements)
- }
-
- def zeros(length: Int) = new Vector(new Array[Double](length))
-
- def ones(length: Int) = Vector(length, _ => 1)
-
- class Multiplier(num: Double) {
- def * (vec: Vector) = vec * num
- }
-
- implicit def doubleToMultiplier(num: Double) = new Multiplier(num)
-
- implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] {
- def addInPlace(t1: Vector, t2: Vector) = t1 + t2
- def zero(initialValue: Vector) = Vector.zeros(initialValue.length)
- }
-}
+}
\ No newline at end of file