diff options
Diffstat (limited to 'src/main/scala')
-rw-r--r-- | src/main/scala/xyz/driver/core/app.scala | 16 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/auth.scala | 120 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/database/Dal.scala | 10 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/database/database.scala | 64 | ||||
-rw-r--r-- | src/main/scala/xyz/driver/core/rest.scala | 131 |
5 files changed, 171 insertions, 170 deletions
diff --git a/src/main/scala/xyz/driver/core/app.scala b/src/main/scala/xyz/driver/core/app.scala index 9bc34f6..54d08d4 100644 --- a/src/main/scala/xyz/driver/core/app.scala +++ b/src/main/scala/xyz/driver/core/app.scala @@ -16,7 +16,8 @@ import org.slf4j.LoggerFactory import spray.json.DefaultJsonProtocol import xyz.driver.core import xyz.driver.core.logging.{Logger, TypesafeScalaLogger} -import xyz.driver.core.rest.{ContextHeaders, Swagger} +import xyz.driver.core.rest.ServiceRequestContext.ContextHeaders +import xyz.driver.core.rest.Swagger import xyz.driver.core.stats.SystemStats import xyz.driver.core.time.Time import xyz.driver.core.time.provider.{SystemTimeProvider, TimeProvider} @@ -68,7 +69,7 @@ object app { val _ = Future { http.bindAndHandle(route2HandlerFlow(handleExceptions(ExceptionHandler(exceptionHandler)) { ctx => - val trackingId = rest.extractTrackingId(ctx) + val trackingId = rest.ServiceRequestContext.extractTrackingId(ctx) log.audit(s"Received request ${ctx.request} with tracking id $trackingId") val contextWithTrackingId = @@ -81,24 +82,29 @@ object app { } } + /** + * Override me for custom exception handling + * + * @return Exception handling route for exception type + */ protected def exceptionHandler = PartialFunction[Throwable, Route] { case is: IllegalStateException => ctx => - val trackingId = rest.extractTrackingId(ctx) + val trackingId = rest.ServiceRequestContext.extractTrackingId(ctx) log.debug(s"Request is not allowed to ${ctx.request.uri} ($trackingId)", is) complete(HttpResponse(BadRequest, entity = is.getMessage))(ctx) case cm: ConcurrentModificationException => ctx => - val trackingId = rest.extractTrackingId(ctx) + val trackingId = rest.ServiceRequestContext.extractTrackingId(ctx) log.audit(s"Concurrent modification of the resource ${ctx.request.uri} ($trackingId)", cm) complete( HttpResponse(Conflict, entity = "Resource was changed concurrently, try requesting a newer version"))(ctx) case t: Throwable => ctx => - val trackingId = rest.extractTrackingId(ctx) + val trackingId = rest.ServiceRequestContext.extractTrackingId(ctx) log.error(s"Request to ${ctx.request.uri} could not be handled normally ($trackingId)", t) complete(HttpResponse(InternalServerError, entity = t.getMessage))(ctx) } diff --git a/src/main/scala/xyz/driver/core/auth.scala b/src/main/scala/xyz/driver/core/auth.scala deleted file mode 100644 index 0b30bc0..0000000 --- a/src/main/scala/xyz/driver/core/auth.scala +++ /dev/null @@ -1,120 +0,0 @@ -package xyz.driver.core - -import akka.http.scaladsl.model.headers.HttpChallenges -import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected -import xyz.driver.core.rest.ServiceRequestContext - -import scala.concurrent.Future -import scala.util.{Failure, Success} -import scalaz.OptionT - -object auth { - - sealed trait Permission - case object CanSeeUser extends Permission - case object CanSeeAssay extends Permission - case object CanSeeReport extends Permission - case object CanCreateReport extends Permission - case object CanEditReport extends Permission - case object CanReviewReport extends Permission - case object CanEditReviewingReport extends Permission - case object CanSignOutReport extends Permission - case object CanAmendReport extends Permission - case object CanShareReportWithPatient extends Permission - case object CanAssignRoles extends Permission - - trait Role { - val id: Id[Role] - val name: Name[Role] - val permissions: Set[Permission] - - def hasPermission(permission: Permission): Boolean = permissions.contains(permission) - } - - case object ObserverRole extends Role { - val id = Id("1") - val name = Name("observer") - val permissions = Set[Permission](CanSeeUser, CanSeeAssay, CanSeeReport) - } - - case object PatientRole extends Role { - val id = Id("2") - val name = Name("patient") - val permissions = Set.empty[Permission] - } - - case object CuratorRole extends Role { - val id = Id("3") - val name = Name("curator") - val permissions = ObserverRole.permissions ++ Set[Permission](CanEditReport, CanReviewReport) - } - - case object PathologistRole extends Role { - val id = Id("4") - val name = Name("pathologist") - val permissions = ObserverRole.permissions ++ - Set[Permission](CanEditReport, CanSignOutReport, CanAmendReport, CanEditReviewingReport) - } - - case object AdministratorRole extends Role { - val id = Id("5") - val name = Name("administrator") - val permissions = CuratorRole.permissions ++ - Set[Permission](CanCreateReport, CanShareReportWithPatient, CanAssignRoles) - } - - case object PhysicianRole extends Role { - val id = Id("6") - val name = Name("physician") - val permissions = Set[Permission]() - } - - case object RelativeRole extends Role { - val id = Id("7") - val name = Name("relative") - val permissions = Set[Permission]() - } - - trait User { - def id: Id[User] - def roles: Set[Role] - def permissions: Set[Permission] = roles.flatMap(_.permissions) - } - - final case class AuthToken(value: String) - - final case class PasswordHash(value: String) - - object AuthService { - val AuthenticationTokenHeader = rest.ContextHeaders.AuthenticationTokenHeader - val SetAuthenticationTokenHeader = "set-authorization" - } - - trait AuthService[U <: User] { - - import akka.http.scaladsl.server._ - import Directives._ - - protected def authStatus(context: ServiceRequestContext): OptionT[Future, U] - - def authorize(permissions: Permission*): Directive1[U] = { - rest.serviceContext flatMap { ctx => - onComplete(authStatus(ctx).run).flatMap { - case Success(Some(user)) => - if (permissions.forall(user.permissions.contains)) provide(user) - else { - val challenge = - HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") - reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) - } - - case Success(None) => - reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided")) - - case Failure(t) => - reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t))) - } - } - } - } -} diff --git a/src/main/scala/xyz/driver/core/database/Dal.scala b/src/main/scala/xyz/driver/core/database/Dal.scala index e920392..0d38282 100644 --- a/src/main/scala/xyz/driver/core/database/Dal.scala +++ b/src/main/scala/xyz/driver/core/database/Dal.scala @@ -6,12 +6,12 @@ import scalaz.{ListT, Monad} import scalaz.std.scalaFuture._ trait Dal { - protected type T[D] - protected implicit val monadT: Monad[T] + type T[D] + implicit val monadT: Monad[T] - protected def execute[D](operations: T[D]): Future[D] - protected def noAction[V](v: V): T[V] - protected def customAction[R](action: => Future[R]): T[R] + def execute[D](operations: T[D]): Future[D] + def noAction[V](v: V): T[V] + def customAction[R](action: => Future[R]): T[R] } class FutureDal(executionContext: ExecutionContext) extends Dal { diff --git a/src/main/scala/xyz/driver/core/database/database.scala b/src/main/scala/xyz/driver/core/database/database.scala index a8aec63..308c391 100644 --- a/src/main/scala/xyz/driver/core/database/database.scala +++ b/src/main/scala/xyz/driver/core/database/database.scala @@ -28,35 +28,55 @@ package database { trait ColumnTypes { val profile: JdbcProfile + } + + trait NameColumnTypes extends ColumnTypes { import profile.api._ + implicit def `xyz.driver.core.Name.columnType`[T]: BaseColumnType[Name[T]] + } - implicit def `xyz.driver.core.Id.columnType`[T]: BaseColumnType[Id[T]] + object NameColumnTypes { + trait StringName extends NameColumnTypes { + import profile.api._ - implicit def `xyz.driver.core.Name.columnType`[T]: BaseColumnType[Name[T]] = - MappedColumnType.base[Name[T], String](_.value, Name[T](_)) + override implicit def `xyz.driver.core.Name.columnType`[T]: BaseColumnType[Name[T]] = + MappedColumnType.base[Name[T], String](_.value, Name[T]) + } + } - implicit def `xyz.driver.core.time.Time.columnType`: BaseColumnType[Time] = - MappedColumnType.base[Time, Long](_.millis, Time(_)) + trait DateColumnTypes extends ColumnTypes { + import profile.api._ + implicit def `xyz.driver.core.time.Date.columnType`: BaseColumnType[Date] + } - implicit def `xyz.driver.core.time.Date.columnType`: BaseColumnType[Date] = - MappedColumnType.base[Date, java.sql.Date](dateToSqlDate(_), sqlDateToDate(_)) + object DateColumnTypes { + trait SqlDate extends DateColumnTypes { + import profile.api._ + override implicit def `xyz.driver.core.time.Date.columnType`: BaseColumnType[Date] = + MappedColumnType.base[Date, java.sql.Date](dateToSqlDate, sqlDateToDate) + } } - object ColumnTypes { - trait UUID extends ColumnTypes { + trait IdColumnTypes extends ColumnTypes { + import profile.api._ + implicit def `xyz.driver.core.Id.columnType`[T]: BaseColumnType[Id[T]] + } + + object IdColumnTypes { + trait UUID extends IdColumnTypes { import profile.api._ override implicit def `xyz.driver.core.Id.columnType`[T] = MappedColumnType .base[Id[T], java.util.UUID](id => java.util.UUID.fromString(id.value), uuid => Id[T](uuid.toString)) } - trait SerialId extends ColumnTypes { + trait SerialId extends IdColumnTypes { import profile.api._ override implicit def `xyz.driver.core.Id.columnType`[T] = MappedColumnType.base[Id[T], Long](_.value.toLong, serialId => Id[T](serialId.toString)) } - trait NaturalId extends ColumnTypes { + trait NaturalId extends IdColumnTypes { import profile.api._ override implicit def `xyz.driver.core.Id.columnType`[T] = @@ -64,6 +84,28 @@ package database { } } + trait TimestampColumnTypes extends ColumnTypes { + import profile.api._ + implicit def `xyz.driver.core.time.Time.columnType`: BaseColumnType[Time] + } + + object TimestampColumnTypes { + trait SqlTimestamp extends TimestampColumnTypes { + import profile.api._ + + override implicit def `xyz.driver.core.time.Time.columnType`: BaseColumnType[Time] = + MappedColumnType.base[Time, java.sql.Timestamp](time => new java.sql.Timestamp(time.millis), + timestamp => Time(timestamp.getTime)) + } + + trait PrimitiveTimestamp extends TimestampColumnTypes { + import profile.api._ + + override implicit def `xyz.driver.core.time.Time.columnType`: BaseColumnType[Time] = + MappedColumnType.base[Time, Long](_.millis, Time(_)) + } + } + trait DatabaseObject extends ColumnTypes { def createTables(): Future[Unit] def disconnect(): Unit diff --git a/src/main/scala/xyz/driver/core/rest.scala b/src/main/scala/xyz/driver/core/rest.scala index ed90d7d..437df3c 100644 --- a/src/main/scala/xyz/driver/core/rest.scala +++ b/src/main/scala/xyz/driver/core/rest.scala @@ -3,7 +3,8 @@ package xyz.driver.core import akka.actor.ActorSystem import akka.http.scaladsl.Http import akka.http.scaladsl.model._ -import akka.http.scaladsl.model.headers.RawHeader +import akka.http.scaladsl.model.headers.{HttpChallenges, RawHeader} +import akka.http.scaladsl.server.AuthenticationFailedRejection.CredentialsRejected import akka.http.scaladsl.unmarshalling.Unmarshal import akka.stream.ActorMaterializer import com.github.swagger.akka.model._ @@ -17,47 +18,119 @@ import xyz.driver.core.time.provider.TimeProvider import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} +import scalaz.OptionT import scalaz.Scalaz.{Id => _, _} object rest { - object ContextHeaders { - val AuthenticationTokenHeader = "WWW-Authenticate" - val TrackingIdHeader = "X-Trace" + final case class ServiceRequestContext( + trackingId: String = generators.nextUuid().toString, + contextHeaders: Map[String, String] = Map.empty[String, String]) { - object LinkerD { - // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/ - def isLinkerD(headerName: String) = headerName.startsWith("l5d-") - } + def authToken: Option[Auth.AuthToken] = + contextHeaders.get(Auth.AuthProvider.AuthenticationTokenHeader).map(Auth.AuthToken.apply) } - final case class ServiceRequestContext( - trackingId: String = generators.nextUuid().toString, - contextHeaders: Map[String, String] = Map.empty[String, String]) + object ServiceRequestContext { + + object ContextHeaders { + val AuthenticationTokenHeader = "WWW-Authenticate" + val TrackingIdHeader = "X-Trace" + + object LinkerD { + // https://linkerd.io/doc/0.7.4/linkerd/protocol-http/ + def isLinkerD(headerName: String) = headerName.startsWith("l5d-") + } + } - import akka.http.scaladsl.server._ - import Directives._ + import akka.http.scaladsl.server._ + import Directives._ - def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx)) + def serviceContext: Directive1[ServiceRequestContext] = extract(ctx => extractServiceContext(ctx)) - def extractServiceContext(ctx: RequestContext): ServiceRequestContext = - ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx)) + def extractServiceContext(ctx: RequestContext): ServiceRequestContext = + ServiceRequestContext(extractTrackingId(ctx), extractContextHeaders(ctx)) - def extractTrackingId(ctx: RequestContext): String = { - ctx.request.headers - .find(_.name == ContextHeaders.TrackingIdHeader) - .fold(java.util.UUID.randomUUID.toString)(_.value()) - } + def extractTrackingId(ctx: RequestContext): String = { + ctx.request.headers + .find(_.name == ContextHeaders.TrackingIdHeader) + .fold(java.util.UUID.randomUUID.toString)(_.value()) + } - def extractContextHeaders(ctx: RequestContext): Map[String, String] = { - ctx.request.headers.filter { h => - h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader - // || ContextHeaders.LinkerD.isLinkerD(h.lowercaseName) - } map { header => - header.name -> header.value - } toMap + def extractContextHeaders(ctx: RequestContext): Map[String, String] = { + ctx.request.headers.filter { h => + h.name === ContextHeaders.AuthenticationTokenHeader || h.name === ContextHeaders.TrackingIdHeader + // || ContextHeaders.LinkerD.isLinkerD(h.lowercaseName) + } map { header => + header.name -> header.value + } toMap + } } + object Auth { + + trait Permission + + trait Role { + val id: Id[Role] + val name: Name[Role] + val permissions: Set[Permission] + + def hasPermission(permission: Permission): Boolean = permissions.contains(permission) + } + + trait User { + def id: Id[User] + def roles: Set[Role] + def permissions: Set[Permission] = roles.flatMap(_.permissions) + } + + final case class BasicUser(id: Id[User], roles: Set[Role]) extends User + + final case class AuthToken(value: String) + + final case class PasswordHash(value: String) + + object AuthProvider { + val AuthenticationTokenHeader = ServiceRequestContext.ContextHeaders.AuthenticationTokenHeader + val SetAuthenticationTokenHeader = "set-authorization" + } + + trait AuthProvider[U <: User] { + + import akka.http.scaladsl.server._ + import Directives._ + + /** + * Specific implementation on how to extract user from request context, + * can either need to do a network call to auth server or extract everything from self-contained token + * + * @param context set of request values which can be relevant to authenticate user + * @return authenticated user + */ + protected def authenticatedUser(context: ServiceRequestContext): OptionT[Future, U] + + def authorize(permissions: Permission*): Directive1[U] = { + ServiceRequestContext.serviceContext flatMap { ctx => + onComplete(authenticatedUser(ctx).run).flatMap { + case Success(Some(user)) => + if (permissions.forall(user.permissions.contains)) provide(user) + else { + val challenge = + HttpChallenges.basic(s"User does not have the required permissions: ${permissions.mkString(", ")}") + reject(AuthenticationFailedRejection(CredentialsRejected, challenge)) + } + + case Success(None) => + reject(ValidationRejection(s"Wasn't able to find authenticated user for the token provided")) + + case Failure(t) => + reject(ValidationRejection(s"Wasn't able to verify token for authenticated user", Some(t))) + } + } + } + } + } trait Service @@ -82,7 +155,7 @@ object rest { val requestTime = time.currentTime() val request = requestStub - .withHeaders(RawHeader(ContextHeaders.TrackingIdHeader, context.trackingId)) + .withHeaders(RawHeader(ServiceRequestContext.ContextHeaders.TrackingIdHeader, context.trackingId)) .withHeaders(context.contextHeaders.toSeq.map { h => RawHeader(h._1, h._2): HttpHeader }: _*) log.audit(s"Sending to ${request.uri} request $request with tracking id ${context.trackingId}") |