diff options
author | Imran Rashid <imran@quantifind.com> | 2012-07-12 12:52:12 -0700 |
---|---|---|
committer | Imran Rashid <imran@quantifind.com> | 2012-07-12 13:08:03 -0700 |
commit | 86024ca74da87907b360963d5a603ef0fcc0a286 (patch) | |
tree | f4e82772d45835bb30b0b2d16d4fca12057511d8 /core | |
parent | 42ce879486f935043ccc21258edb34a4c20d1a8d (diff) | |
download | spark-86024ca74da87907b360963d5a603ef0fcc0a286.tar.gz spark-86024ca74da87907b360963d5a603ef0fcc0a286.tar.bz2 spark-86024ca74da87907b360963d5a603ef0fcc0a286.zip |
add some functionality to Vector, delete copy in AccumulatorSuite
Diffstat (limited to 'core')
-rw-r--r-- | core/src/main/scala/spark/util/Vector.scala | 32 | ||||
-rw-r--r-- | core/src/test/scala/spark/AccumulatorSuite.scala | 114 |
2 files changed, 33 insertions, 113 deletions
diff --git a/core/src/main/scala/spark/util/Vector.scala b/core/src/main/scala/spark/util/Vector.scala index e5604687e9..4e95ac2ac6 100644 --- a/core/src/main/scala/spark/util/Vector.scala +++ b/core/src/main/scala/spark/util/Vector.scala @@ -29,7 +29,37 @@ class Vector(val elements: Array[Double]) extends Serializable { return ans } - def * (scale: Double): Vector = Vector(length, i => this(i) * scale) + /** + * return (this + plus) dot other, but without creating any intermediate storage + * @param plus + * @param other + * @return + */ + def plusDot(plus: Vector, other: Vector): Double = { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + if (length != plus.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + ans += (this(i) + plus(i)) * other(i) + i += 1 + } + return ans + } + + def +=(other: Vector) { + if (length != other.length) + throw new IllegalArgumentException("Vectors of different length") + var ans = 0.0 + var i = 0 + while (i < length) { + elements(i) += other(i) + i += 1 + } + } + def * (scale: Double): Vector = Vector(length, i => this(i) * scale) def / (d: Double): Vector = this * (1 / d) diff --git a/core/src/test/scala/spark/AccumulatorSuite.scala b/core/src/test/scala/spark/AccumulatorSuite.scala index 2297ecf50d..24c4591034 100644 --- a/core/src/test/scala/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/spark/AccumulatorSuite.scala @@ -64,8 +64,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers { test ("value readable in tasks") { - import Vector.VectorAccumParam._ - import Vector._ + import spark.util.Vector //stochastic gradient descent with weights stored in accumulator -- should be able to read value as we go //really easy data @@ -121,113 +120,4 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers { } } -} - - - -//ugly copy and paste from examples ... -class Vector(val elements: Array[Double]) extends Serializable { - def length = elements.length - - def apply(index: Int) = elements(index) - - def + (other: Vector): Vector = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) + other(i)) - } - - def - (other: Vector): Vector = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - return Vector(length, i => this(i) - other(i)) - } - - def dot(other: Vector): Double = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - var ans = 0.0 - var i = 0 - while (i < length) { - ans += this(i) * other(i) - i += 1 - } - return ans - } - - def plusDot(plus: Vector, other: Vector): Double = { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - if (length != plus.length) - throw new IllegalArgumentException("Vectors of different length") - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) + plus(i)) * other(i) - i += 1 - } - return ans - } - - def += (other: Vector) { - if (length != other.length) - throw new IllegalArgumentException("Vectors of different length") - var ans = 0.0 - var i = 0 - while (i < length) { - elements(i) += other(i) - i += 1 - } - } - - - def * (scale: Double): Vector = Vector(length, i => this(i) * scale) - - def / (d: Double): Vector = this * (1 / d) - - def unary_- = this * -1 - - def sum = elements.reduceLeft(_ + _) - - def squaredDist(other: Vector): Double = { - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) - other(i)) * (this(i) - other(i)) - i += 1 - } - return ans - } - - def dist(other: Vector): Double = math.sqrt(squaredDist(other)) - - override def toString = elements.mkString("(", ", ", ")") -} - -object Vector { - def apply(elements: Array[Double]) = new Vector(elements) - - def apply(elements: Double*) = new Vector(elements.toArray) - - def apply(length: Int, initializer: Int => Double): Vector = { - val elements = new Array[Double](length) - for (i <- 0 until length) - elements(i) = initializer(i) - return new Vector(elements) - } - - def zeros(length: Int) = new Vector(new Array[Double](length)) - - def ones(length: Int) = Vector(length, _ => 1) - - class Multiplier(num: Double) { - def * (vec: Vector) = vec * num - } - - implicit def doubleToMultiplier(num: Double) = new Multiplier(num) - - implicit object VectorAccumParam extends spark.AccumulatorParam[Vector] { - def addInPlace(t1: Vector, t2: Vector) = t1 + t2 - def zero(initialValue: Vector) = Vector.zeros(initialValue.length) - } -} +}
\ No newline at end of file |