diff options
-rw-r--r-- | src/compiler/scala/tools/nsc/classpath/AggregateFlatClassPath.scala | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/src/compiler/scala/tools/nsc/classpath/AggregateFlatClassPath.scala b/src/compiler/scala/tools/nsc/classpath/AggregateFlatClassPath.scala index 3768ed19ac..91026d0e13 100644 --- a/src/compiler/scala/tools/nsc/classpath/AggregateFlatClassPath.scala +++ b/src/compiler/scala/tools/nsc/classpath/AggregateFlatClassPath.scala @@ -31,21 +31,23 @@ case class AggregateFlatClassPath(aggregates: Seq[FlatClassPath]) extends FlatCl } override def findClass(className: String): Option[ClassRepresentation] = { - val (pkg, simpleClassName) = PackageNameUtils.separatePkgAndClassNames(className) - @tailrec - def findEntry[T <: ClassRepClassPathEntry](aggregates: Seq[FlatClassPath], getEntries: FlatClassPath => Seq[T]): Option[T] = + def findEntry(aggregates: Seq[FlatClassPath], isSource: Boolean): Option[ClassRepresentation] = if (aggregates.nonEmpty) { - val entry = getEntries(aggregates.head).find(_.name == simpleClassName) + val entry = aggregates.head.findClass(className) match { + case s @ Some(_: SourceFileEntry) if isSource => s + case s @ Some(_: ClassFileEntry) if !isSource => s + case _ => None + } if (entry.isDefined) entry - else findEntry(aggregates.tail, getEntries) + else findEntry(aggregates.tail, isSource) } else None - val classEntry = findEntry(aggregates, classesGetter(pkg)) - val sourceEntry = findEntry(aggregates, sourcesGetter(pkg)) + val classEntry = findEntry(aggregates, isSource = false) + val sourceEntry = findEntry(aggregates, isSource = true) (classEntry, sourceEntry) match { - case (Some(c), Some(s)) => Some(ClassAndSourceFilesEntry(c.file, s.file)) + case (Some(c: ClassFileEntry), Some(s: SourceFileEntry)) => Some(ClassAndSourceFilesEntry(c.file, s.file)) case (c @ Some(_), _) => c case (_, s) => s } @@ -63,10 +65,10 @@ case class AggregateFlatClassPath(aggregates: Seq[FlatClassPath]) extends FlatCl } override private[nsc] def classes(inPackage: String): Seq[ClassFileEntry] = - getDistinctEntries(classesGetter(inPackage)) + getDistinctEntries(_.classes(inPackage)) override private[nsc] def sources(inPackage: String): Seq[SourceFileEntry] = - getDistinctEntries(sourcesGetter(inPackage)) + getDistinctEntries(_.sources(inPackage)) override private[nsc] def list(inPackage: String): FlatClassPathEntries = { val (packages, classesAndSources) = aggregates.map(_.list(inPackage)).unzip @@ -121,9 +123,6 @@ case class AggregateFlatClassPath(aggregates: Seq[FlatClassPath]) extends FlatCl } entriesBuffer.toIndexedSeq } - - private def classesGetter(pkg: String) = (cp: FlatClassPath) => cp.classes(pkg) - private def sourcesGetter(pkg: String) = (cp: FlatClassPath) => cp.sources(pkg) } object AggregateFlatClassPath { |