継続モナド変換子(cats.data.ContT)を使った大域脱出

catsの継続モナド変換子ContTを使って条件分岐による大域脱出を試してみました
2019.09.18

はじめに

今回はcats.effect.IOを使った記述した処理を途中で中断する、というコードが書きたかったのでCatsに定義されている継続モナド変換子であるContT を試してみました。

やりたいこと

前提として何らかの会員情報と会員ごとのポイントを管理するシステムがあるとします。

このシステムで以下の処理が既に実装済みとなっています。

  1. 会員のポイント情報を2倍する
  2. 1.で変更した会員情報をDBに保存する
  3. 1.で変更した会員情報をイベントシステムに送信する

コードでいうと以下のような感じです(DB保存などの詳細はスキップしています)

import scala.language.higherKinds

object Main extends IOApp {
  //会員
  final case class Member(tier: Tier, points: Long)

  //会員のtier
  sealed trait Tier
  final case object Gold extends Tier
  final case object Green extends Tier


  //定義済みの処理
  def doublePoint(member: Member): IO[Member] =
    IO.pure(member.copy(points = member.points * 2))

  def saveMemberToDB(member: Member): IO[Member] = IO {
    println(s"saving member ${member} to DB")
    member
  }

  def publishMember(member: Member): IO[Member] = IO {
    println(s"publishing ${member} to EventQueue")
    member
  }

  //ポイントを2倍してDBなどに保存する
  def doublePointsAndSave(member: Member): IO[Member] =
    for {
      doubled <- doublePoint(member)
      _ <- saveMemberToDB(doubled)
      _ <- publishMember(doubled)
    } yield doubled
}

仕様変更

この処理を以下のように変更します。

  1. 会員のポイント情報を2倍する、ただしtierがGoldの会員のみ
  2. 1.で変更した会員情報をDBに保存する
  3. 1.で変更した会員情報をイベントシステムに送信する

この変更をするにあたってtierがGreenの会員の場合はDB保存以降の処理を行わないように最適化したいです。

で、どうするのか、というと、ここでContTを使ってみます。

まず下記のようにtierに応じて条件分岐する処理を記述します。 tierがGoldの時のみポイントを2倍して、後に続く処理(next)を実行します。

// ContTを使って条件分岐を追加した処理
  def doublePointsIfTierIsGold(member: Member): ContT[IO, Member, Member] =
    ContT { next =>
      member match {
        case Member(Gold, _) =>
          doublePoint(member).flatMap(next)
        case Member(Green, _) => IO.pure(member)
      }
    }

で、残りの処理ですが、1つづつContTを積むのは冗長なので下記のようにまとめてしまいます。

// tier=Goldならポイントを2倍して、DBなどに保存する
  def doublePointsAndSaveWithCondition(member: Member): IO[Member] =
    (for {
      doubled <- doublePointsIfTierIsGold(member)
      //tierがGreenの場合は下記の処理をスキップしたい
      _ <- ContT[IO, Member, Member] { _ =>
        for {
          _ <- saveMemberToDB(doubled)
          _ <- publishMember(doubled)
        } yield doubled
      }
    } yield doubled).run(IO.pure)

以上をまとめると以下のようになります。

import cats.data.ContT
import cats.effect.{ExitCode, IO, IOApp}

import scala.language.higherKinds

object Main extends IOApp {

  sealed trait Tier
  final case object Gold extends Tier
  final case object Green extends Tier

  final case class Member(tier: Tier, points: Long)

  //定義済みの処理

  def doublePoint(member: Member): IO[Member] =
    IO.pure(member.copy(points = member.points * 2))

  def saveMemberToDB(member: Member): IO[Member] = IO {
    println(s"saving member ${member} to DB")
    member
  }

  def publishMember(member: Member): IO[Member] = IO {
    println(s"publishing ${member} to EventQueue")
    member
  }

  // ContTを使って条件分岐を追加した処理
  def doublePointsIfTierIsGold(member: Member): ContT[IO, Member, Member] =
    ContT { next =>
      member match {
        case Member(Gold, _) =>
          doublePoint(member).flatMap(next)
        case Member(Green, _) => IO.pure(member)
      }
    }

  // tier=Goldならポイントを2倍して、DBなどに保存する
  def doublePointsAndSaveWithCondition(member: Member): IO[Member] =
    (for {
      doubled <- doublePointsIfTierIsGold(member)
      //tierがGreenの場合は下記の処理をスキップしたい
      _ <- ContT[IO, Member, Member] { _ =>
        for {
          _ <- saveMemberToDB(doubled)
          _ <- publishMember(doubled)
        } yield doubled
      }
    } yield doubled).run(IO.pure)

  //ポイントを2倍してDBなどに保存する
  def doublePointsAndSave(member: Member): IO[Member] =
    for {
      doubled <- doublePoint(member)
      _ <- saveMemberToDB(doubled)
      _ <- publishMember(doubled)
    } yield doubled

  override def run(args: List[String]): IO[ExitCode] = {
    val member = Member(Gold, 100)
    for {
      _ <- IO(println("条件なし(tier=Gold)"))
      _ <- doublePointsAndSave(member)
      _ <- IO(println("継続による条件分岐(tier=Green)"))
      _ <- doublePointsAndSaveWithCondition(member.copy(Green))
      _ <- IO(println("継続による条件分岐(tier=Gold)"))
      _ <- doublePointsAndSaveWithCondition(member)
    } yield ExitCode.Success

  }

}

実行例

以下が実行した出力です。条件分岐あり、tier=GreenのパターンではDB保存などの後続処理が行われていないのが確認できました。

条件なし(tier=Gold)
saving member Member(Gold,200) to DB
publishing Member(Gold,200) to EventQueue
継続による条件分岐(tier=Green)
継続による条件分岐(tier=Gold)
saving member Member(Gold,200) to DB
publishing Member(Gold,200) to EventQueue

まとめと考察

IOで記述された処理をContTを使うこと条件に応じて途中から脱出するように変更できました。

一応やりたかったことはできたのですが、いくつか気になる点が残りました。

  • 処理を中断する箇所が複数ある場合、その都度前後の処理にContTを積む必要があるのがスマートではないように思える
  • ContT[M[_], A, B]M[_]cats.Deferが必要なのでとりあえず各処理をIOで記述してみたが高カインド型にしたい