author     Feynman Liang <fliang@databricks.com>        2015-06-30 12:31:33 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2015-06-30 12:31:33 -0700
commit     74cc16dbc35e35fd5cd5542239dcb6e5e7f92d18 (patch)
tree       c700c4ee420c77fe640b8c9462536a7a7b9e9e3e /mllib
parent     b8e5bb6fc1553256e950fdad9cb5acc6b296816e (diff)
[SPARK-8471] [ML] Discrete Cosine Transform Feature Transformer
Implementation and tests for Discrete Cosine Transformer.

Author: Feynman Liang <fliang@databricks.com>

Closes #6894 from feynmanliang/dct-features and squashes the following commits:

433dbc7 [Feynman Liang] Test refactoring
91e9636 [Feynman Liang] Style guide and test helper refactor
b5ac19c [Feynman Liang] Use Vector types, add Java test
530983a [Feynman Liang] Tests for other numeric datatypes
195d7aa [Feynman Liang] Implement support for arbitrary numeric types
95d4939 [Feynman Liang] Working DCT for 1D Doubles
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala          72
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java    78
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala      73
3 files changed, 223 insertions(+), 0 deletions(-)
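For orientation, here is a minimal usage sketch of the new transformer. The API calls mirror the tests in this patch; the Signal case class, column names, and sample values are hypothetical, and a SQLContext named sqlContext is assumed to be in scope.

    import org.apache.spark.ml.feature.DiscreteCosineTransformer
    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    case class Signal(vec: Vector)

    // Build a one-row DataFrame with a vector column to transform.
    val df = sqlContext.createDataFrame(Seq(Signal(Vectors.dense(1.0, 2.0, 3.0, 4.0))))

    // Configure the forward (non-inverse) scaled DCT-II.
    val dct = new DiscreteCosineTransformer()
      .setInputCol("vec")
      .setOutputCol("dctVec")
      .setInverse(false)

    // The output column holds a real vector of the same length as the input.
    dct.transform(df).select("dctVec").show()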
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala
new file mode 100644
index 0000000000..a2f4d59f81
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import edu.emory.mathcs.jtransforms.dct._
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.BooleanParam
+import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * :: Experimental ::
+ * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero
+ * padding is performed on the input vector.
+ * It returns a real vector of the same length representing the DCT. The return vector is scaled
+ * such that the transform matrix is unitary (aka scaled DCT-II).
+ *
+ * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]].
+ */
+@Experimental
+class DiscreteCosineTransformer(override val uid: String)
+ extends UnaryTransformer[Vector, Vector, DiscreteCosineTransformer] {
+
+ def this() = this(Identifiable.randomUID("dct"))
+
+ /**
+ * Indicates whether to perform the inverse DCT (true) or forward DCT (false).
+ * Default: false
+ * @group param
+ */
+ def inverse: BooleanParam = new BooleanParam(
+ this, "inverse", "Set transformer to perform inverse DCT")
+
+ /** @group setParam */
+ def setInverse(value: Boolean): this.type = set(inverse, value)
+
+ /** @group getParam */
+ def getInverse: Boolean = $(inverse)
+
+ setDefault(inverse -> false)
+
+ override protected def createTransformFunc: Vector => Vector = { vec =>
+ val result = vec.toArray
+ val jTransformer = new DoubleDCT_1D(result.length)
+ if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true)
+ Vectors.dense(result)
+ }
+
+ override protected def validateInputType(inputType: DataType): Unit = {
+ require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.")
+ }
+
+ override protected def outputDataType: DataType = new VectorUDT
+}
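For reference, the "scaled DCT-II" mentioned in the scaladoc above is the orthonormal form of the type-II DCT. A sketch of the standard definition, in LaTeX (background material, not taken from the patch):

    X_k = c_k \sum_{n=0}^{N-1} x_n \cos\left[\frac{\pi}{N}\left(n + \tfrac{1}{2}\right)k\right],
    \qquad c_0 = \sqrt{1/N}, \qquad c_k = \sqrt{2/N} \ \text{for } k \ge 1.

With this scaling the N-by-N transform matrix is orthogonal (unitary over the reals), which is the property the scaladoc refers to when the scale flag passed to the JTransforms forward/inverse calls is true.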
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java
new file mode 100644
index 0000000000..28bc5f65e0
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature;
+
+import com.google.common.collect.Lists;
+import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.VectorUDT;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.RowFactory;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaDiscreteCosineTransformerSuite {
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaDiscreteCosineTransformerSuite");
+ jsql = new SQLContext(jsc);
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void javaCompatibilityTest() {
+ double[] input = new double[] {1D, 2D, 3D, 4D};
+ JavaRDD<Row> data = jsc.parallelize(Lists.newArrayList(
+ RowFactory.create(Vectors.dense(input))
+ ));
+ DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{
+ new StructField("vec", (new VectorUDT()), false, Metadata.empty())
+ }));
+
+ double[] expectedResult = input.clone();
+ (new DoubleDCT_1D(input.length)).forward(expectedResult, true);
+
+ DiscreteCosineTransformer DCT = new DiscreteCosineTransformer()
+ .setInputCol("vec")
+ .setOutputCol("resultVec");
+
+ Row[] result = DCT.transform(dataset).select("resultVec").collect();
+ Vector resultVec = result[0].getAs("resultVec");
+
+ Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala
new file mode 100644
index 0000000000..ed0fc11f78
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.beans.BeanInfo
+
+import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row}
+
+@BeanInfo
+case class DCTTestData(vec: Vector, wantedVec: Vector)
+
+class DiscreteCosineTransformerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("forward transform of discrete cosine matches jTransforms result") {
+ val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
+ val inverse = false
+
+ testDCT(data, inverse)
+ }
+
+ test("inverse transform of discrete cosine matches jTransforms result") {
+ val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray)
+ val inverse = true
+
+ testDCT(data, inverse)
+ }
+
+ private def testDCT(data: Vector, inverse: Boolean): Unit = {
+ val expectedResultBuffer = data.toArray.clone()
+ if (inverse) {
+ (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true)
+ } else {
+ (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true)
+ }
+ val expectedResult = Vectors.dense(expectedResultBuffer)
+
+ val dataset = sqlContext.createDataFrame(Seq(
+ DCTTestData(data, expectedResult)
+ ))
+
+ val transformer = new DiscreteCosineTransformer()
+ .setInputCol("vec")
+ .setOutputCol("resultVec")
+ .setInverse(inverse)
+
+ transformer.transform(dataset)
+ .select("resultVec", "wantedVec")
+ .collect()
+ .foreach { case Row(resultVec: Vector, wantedVec: Vector) =>
+ assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6)
+ }
+ }
+}
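Since the scaled DCT-II is unitary, a forward transform followed by an inverse transform should recover the original vector up to floating-point error. A hypothetical extra helper along the lines of testDCT above (not part of the patch) could check that round trip:

      // Sanity check: forward DCT followed by inverse DCT recovers the input.
      private def testRoundTrip(data: Vector): Unit = {
        // Reuse DCTTestData, storing the original vector as the expected result.
        val dataset = sqlContext.createDataFrame(Seq(DCTTestData(data, data)))

        val forward = new DiscreteCosineTransformer()
          .setInputCol("vec")
          .setOutputCol("fwdVec")
          .setInverse(false)

        val backward = new DiscreteCosineTransformer()
          .setInputCol("fwdVec")
          .setOutputCol("roundTripVec")
          .setInverse(true)

        backward.transform(forward.transform(dataset))
          .select("roundTripVec", "wantedVec")
          .collect()
          .foreach { case Row(roundTripVec: Vector, wantedVec: Vector) =>
            assert(Vectors.sqdist(roundTripVec, wantedVec) < 1e-6)
          }
      }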