限定継続(ContT)でプログラムの実行制御する

限定継続(ContT)のshift-resetを使ったフロー制御を試してみました。
2022.09.14

はじめに

PPLサマースクール で聴講した講義「限定継続を使った計算効果プログラミング」が面白かったのでcatsのContT (限定継続モナド)でおさらいをしてみました。

shiftとresetによる限定継続の紹介は「shift/resetプログラミング」(Tutorial notesからリンクされています)が詳しいです。

リストの積

以下のようなリストの要素の積を求める関数を考えます。

times :: Numeric a => [a] -> a
times []               = 1
times x:xs             = * x (times xs)

シンプルな実装ですが、引数のリストに0が含まれる場合は戻り値が0になることは明らかなので他の要素の計算をスキップしたいです。

そこで以下のように最適化します。

times :: Numeric a => [a] -> a
times []               = 1
times x:xs | x == 0    = 0
           | otherwise = * x (times xs)

しかし、0より前に見つかった要素の積は計算済みになっています。限定継続(ContT)を使ってこれを最適化します。

計算経過を示すために途中で標準出力しています。2つめの試行では0を含むため、前後の計算がスキップされています。

import cats.Show
import cats.data.ContT
import cats.effect.{IO, IOApp}
import cats.syntax.all.*

object ShiftReset extends IOApp.Simple:

  type Cont[N] = ContT[IO, N, N]
	//計算の経過を示すためにprint付きのヘルパーを定義しておく
  def pure[A:Show](a:A):Cont[A] = ContT.liftF(putStrLn(a) *> IO.pure(a))
  def times[N: Numeric](a:N, b:N):Cont[N] = ContT.liftF(putStrLn(s"$a times $b") *> IO.pure(Numeric[N].times(a,b)))
  def putStrLn[A:Show](a:A) = IO(println(a.show))

  def times[N: Numeric: Show](lst: List[N]):Cont[N] = lst match {
      case Nil => pure(Numeric[N].one)
      case x::xs =>
        if(x == 0) ContT.shiftT(_ => pure(Numeric[N].zero))
        else times(xs).flatMap(times(x,_))
    }

  def run = for {
    _ <- ContT.resetT(times(List(1,2,3,4,5))).eval >>= putStrLn
    _ <- putStrLn("****************************************")
    _ <- ContT.resetT(times(List(1,2,0,4,5))).eval >>= putStrLn
  } yield ()

// 1
// 5 times 1
// 4 times 5
// 3 times 20
// 2 times 60
// 1 times 120
// 120
// ****************************************
// 0
// 0

まとめ

サマースクールでは限定継続を使ったテクニックが他にも紹介されていたので今後も試してみようと思います。