aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala14
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala14
2 files changed, 27 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 71034c2c43..2cf241de61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -118,7 +118,19 @@ trait ScalaReflection {
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
- val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
+ val constructorSymbol = t.member(nme.CONSTRUCTOR)
+ val params = if (constructorSymbol.isMethod) {
+ constructorSymbol.asMethod.paramss
+ } else {
+ // Find the primary constructor, and use its parameter ordering.
+ val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find(
+ s => s.isMethod && s.asMethod.isPrimaryConstructor)
+ if (primaryConstructorSymbol.isEmpty) {
+ sys.error("Internal SQL error: Product object did not have a primary constructor.")
+ } else {
+ primaryConstructorSymbol.get.asMethod.paramss
+ }
+ }
Schema(StructType(
params.head.map { p =>
val Schema(dataType, nullable) =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index ddc3d44869..7be24bea7d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -68,6 +68,10 @@ case class ComplexData(
case class GenericData[A](
genericField: A)
+case class MultipleConstructorsData(a: Int, b: String, c: Double) {
+ def this(b: String, a: Int) = this(a, b, c = 1.0)
+}
+
class ScalaReflectionSuite extends FunSuite {
import ScalaReflection._
@@ -253,4 +257,14 @@ class ScalaReflectionSuite extends FunSuite {
Row(1, 1, 1, 1, 1, 1, true))
assert(convertToCatalyst(data, dataType) === convertedData)
}
+
+ test("infer schema from case class with multiple constructors") {
+ val dataType = schemaFor[MultipleConstructorsData].dataType
+ dataType match {
+ case s: StructType =>
+ // Schema should have order: a: Int, b: String, c: Double
+ assert(s.fieldNames === Seq("a", "b", "c"))
+ assert(s.fields.map(_.dataType) === Seq(IntegerType, StringType, DoubleType))
+ }
+ }
}