セカイノカタチ

世界のカタチを探求するブログ。関数型言語に興味があり、HaskellやScalaを勉強中。最近はカメラの話題も多め

「関数プログラミング 珠玉のアルゴリズムデザイン」をScalaで実装してみる 第6章

前回: 「関数プログラミング 珠玉のアルゴリズムデザイン」をScalaで実装してみる 第5章 - セカイノカタチ

関数プログラミング 珠玉のアルゴリズムデザイン

関数プログラミング 珠玉のアルゴリズムデザイン

この本のコードをScalaに置き換えながらじわじわ進むシリーズの第6弾です。

まだ心折れていなかったようです。驚きですね。(^^;

小町算

小町算とは、一桁の数のリストに+と×を差し込んで全体の式の値が100になるものをすべて列挙するという問題です。

100 = 12 + 34 + 4 * 6 + 7 + 8 + 9
100 = 1 + 2 * 3 + 4 + 5 + 67 + 8 + 9

カッコが使えないのと、隣り合う数字をつなげて2桁以上の数字を作れるあたりが注意点でしょうか。

前半戦

理屈はよくわからないですが、運算を繰り返して式を得ます。

この途中に「foldrの融合則」というものが出てきます*1が、このブログで以前に解説していますので、理解の助けになるかもしれません。

qtamaki.hatenablog.com

それだけだと、辛いかもしれないので、分かる範囲で解説します。

「foldrの融合則」とは以下3つの規則を満たすと、gをhに融合してfoldrの計算効率が上がるというものです。

  • f が正格関数であること
  • f a = b であること
  • 任意のx,yに対し、f (g x y) = h x (f y) であること。

今回の例で言うと、fは filter (ok . value) とのことなので、f,g,hが下記のような形で展開可能となります。

f :: [Candidate] -> [Candidate]
f = filter (ok . value)
g :: datum -> [Candidate] -> [Candidate]
g = extend
h :: datum -> [Candidate] -> [Candidate]
h = extend'
f (g datum [candidate]) = h datum (f [candidate])

ここでは、extend'の特性を示しているだけなので、実際にextend'の実装をどうするかについては言及がありません。

これは、具体的なextend'の実装を見る前に、extend'が持つべき性質を洗い出して実装を逆算していくのに使うためです。

この本では、このように実装に入る前にインターフェースを仮定することによって式を展開していく手法が随所に見られます。

この章でも、仮定の連続によって式がどんどん展開されるので内容を理解しながら読んでいくのが相当骨です。

時には、途中の状態を微妙に端折って次の段階に進んでいたりすることもあり、難易度が高い式変形があります。

例えば、下記の変形ですが、法則(6.6)と「map (f · g) = map f · map g」を組合せて展開しています。

= fork (map fst, map snd) . map (fork (f,g))
= fork (map (fst . fork (f,g)), map (snd . fork (f,g)))

ここは、中間の状態として下記の式を補うとわかりやすいと思います。

= fork (map fst . map (fork (f,g)), map snd . map (fork (f,g)))

同じく、下記の展開については、それぞれの行間に中間式を補います。

= zip . fork (extend x, modify x . map value) -- 6.4
= zip . cross (extend x, modify x) . fork (id, map value) -- 6.7
= zip . cross (extend x, modify x) . unzip . map (fork (id, value)) -- 6.8

extend xに仮にidを合成してf.gの形にします。

= zip . fork (extend x . id , modify x . map value)

こちらも仮にidを補っています。

= zip . cross (extend x, modify x) . fork (map id, map value)

これらの難関を突破しながら、なんとか前半戦の最後に以下の2つの式にたどり着きます。

solutions = map fst . filter (good.snd) . foldr expand []
expand = filter (ok.snd) . zip . cross (extend x, modify x) . unzip

本には、「2つの運算を合わせて、以下に到達する。」と書いてあるのですが、solutionsの式を導き出すような運算は全くありません。(^^;

expandの方も、運算結果にはunzipの後に「map (fork (id, value))」が合成されているのですが、なくなってます(((^^;;;;;;

この辺まったく理解できていないので、解説してくださる方がいらっしゃいましたら、助けてください。(ぉ

追記: などと悲鳴をあげていたら、サポートサイトに解説が追加されていました。

06. 小町算

凄い。やっと理解出来ました。(^^;

後半戦

そして後半戦ですが、あんなに苦労した前半戦の運算をとりあえず棚上げして、さくっと問題を解き始めます。(^^;

type Expression = [Term]
type Term = [Factor]
type Factor = [Digit]
type Digit = Int

valExpr :: Expression -> Int
valExpr = sum . map valTerm
valTerm :: Term -> Int
valTerm = product . map valFact
valFact :: Factor -> Int
valFact = foldl1 (\n d -> 10 * n + d)

good :: Int -> Bool
good v = (v == 100)

expressions :: [Digit] -> [Expression]
-- expressions = concatMap partitions . partitions
expressions = foldr extend []

partitions :: [a] -> [[[a]]]
partitions [] = [[]]
partitions (x:xs) = [[x] : p | p <- ps] ++ [(x:ys):yss | ys:yss <- ps]
                    where ps = partitions xs

extend :: Digit -> [Expression] -> [Expression]
extend x [] = [[[[x]]]]
extend x es = concatMap (glue x) es
glue :: Digit -> Expression -> [Expression]
glue x ((xs:xss):xsss) = [((x:xs):xss):xsss,
                           ([x]:xs:xss):xsss,
                           [[x]]:(xs:xss):xsss]

epressionsに関しては、2パターンあって、extendを使用するほうが、前半戦の説明あった部分だと思います。

特に難しいことはないと思います。

Scalaのコードを示します。

object Komachi {

  type Digit = Int
  type Factor = List[Digit]
  type Term = List[Factor]
  type Expression = List[Term]

  def valExpr(ex: Expression): Int = ex.map(valTerm).foldLeft(0)(_+_)

  def valTerm(te: Term): Int = te.map(valFactor).foldLeft(1)(_*_)

  def valFactor(fa: Factor): Int = fa.tail.foldLeft(fa.head)((n:Int, d:Int) => 10 * n + d ) 

  def good(v:Int):Boolean = v == 100

  //def expressions(ds: List[Digit]): List[Expression] = partitions(ds).flatMap(partitions)
  def expressions(ds: List[Digit]): List[Expression] = ds.foldRight[List[Expression]](Nil)(extend)

  def extend(x: Digit, exs: List[Expression]): List[Expression] = {
    exs match {
      case Nil => List(List(List(List(x))))
      case es => es.flatMap(glue(x, _))
    }
  }

  def glue(x:Digit, ex: Expression): List[Expression] = ex match {
    case ((xs::xss)::xsss) => List(((x::xs)::xss)::xsss,
                              (List(x)::xs::xss)::xsss,
                              List(List(x))::(xs::xss)::xsss)
  }

  def partitions[A](a: List[A]): List[List[List[A]]] = a match {
    case Nil => List(Nil)
    case x::xs => {
      lazy val ps = partitions(xs)
      val a = for(p <- ps) yield List(x)::p
      val b = for(p <- ps if p != Nil) yield {
        val ys = x :: p.head
        val yys = p.tail
        ys :: yys 
      }
      a ++ b
    }
  }
}

かなり、ゴチャゴチャしましたが、素直に書き換え出来ました。

型推論も概ねうまく行っています。(partitionsの所で一部List[Lsit[Lsit[Any]]]になってしまったので若干書き換えています)

そして、先ほどの運算で得られた式に書き換えます。

modify x (k,f,t,e) = [(10*k,k*x+f,t,e), (10,x,f*t,e),(10,x,1,f*t+e)]

good c (k,f,t,e) = (f*t+e == c)
ok c (k,f,t,e) = (f*t+e <= c)

solutions c = map fst . filter (good c . snd) . foldr (expand c) []
expand x c = filter (ok c . snd) . zip . cross (extend x, modify x) . unzip

差分のみを示しました。

(k,f,t,e)って何!?(゜o゜;

「単に計算効率を良くするものだ」と言い切ってますが、理解できません。(^^;;;

そして、expandの2パターンあって、「少し単純にできる」と下記のように書き換えられます。

expand c x [] = [([[[x]]], (10,x,1,0))]
expand c x evs = concatMap (filter (ok c . snd) . glue x) evs
glue x ((xs:xss):xsss,(k,f,t,e)) =
  [(((x:xs):xss):xsss,(10*k,k*x+f,t,e)),
   (([x]:xs:xss):xsss,(10,x,f*t,e)),
   ([[x]]:(xs:xss):xsss,(10,x,1,f*t+e))]

単純ってレベルじゃねーぞ!?

・・・と、いうことで細部は理解できませんでしたが、とにかく動いているようです。(ぉ

Scalaにします。

object Komachi2 {

  type Digit = Int
  type Factor = List[Digit]
  type Term = List[Factor]
  type Expression = List[Term]
  type Digits = (Digit,Digit,Digit,Digit)

  def valExpr(ex: Expression): Int = ex.map(valTerm).foldLeft(0)(_+_)

  def valTerm(te: Term): Int = te.map(valFactor).foldLeft(1)(_*_)

  def valFactor(fa: Factor): Int = fa.tail.foldLeft(fa.head)((n:Int, d:Int) => 10 * n + d ) 

  def modify(x:Digit, ds:Digits): List[Digits] = {
    val (k,f,t,e) = ds
    List((10*k,k*x+f,t,e),(10,x,f*t,e),(10,x,1,f*t+e))
  }

  def good(c: Int, ds:Digits):Boolean = {
    val (k,f,t,e) = ds
    (f*t+e) == c
  }

  def ok(c: Int, ds:Digits):Boolean = {
    val (k,f,t,e) = ds
    (f*t+e) <= c
  }

  def solutions(c: Int, ds: List[Digit]): List[Expression] = 
    ds.foldRight[List[(Expression,Digits)]](Nil)((d,xs) => expand(c, d, xs)).filter((x) => good(c, x._2)).map(_._1)

  def expand(c:Int, x:Digit, xs: List[(Expression, Digits)]): List[(Expression, Digits)] = xs match {
    case Nil => List((List(List(List(x))), (10,x,1,0)))
    case evs => evs.flatMap((y) => glue(x, y).filter((z) => ok(c, z._2)))
  }

  def glue(x:Digit, ex: (Expression, Digits)): List[(Expression, Digits)] = ex match {
    case (((xs::xss)::xsss), (k,f,t,e)) => List(
                              (((x::xs)::xss)::xsss, (10*k,k*x+f,t,e)),
                              ((List(x)::xs::xss)::xsss, (10,x,f*t,e)),
                              (List(List(x))::(xs::xss)::xsss,(10,x,1,f*t+e)))
  }
}

こちらもゴチャゴチャした割に素直に書き換えられました。

オマケ

オマケに本では、サクッと流されている適切に出力を調整する部分を書いてみたので置いておきますね。(Haskellのみ)

mkstr s = foldl1 (\x y -> x ++ s ++ y) showFact xs = mkstr "" $ map show xs showTerm xs = mkstr "x" $ map showFact xs showExpr xs = mkstr "+" $ map showTerm xs showAll = mapM_ print $ map (\x -> "100=" ++ showExpr x) . solutions 100 $ [1..9]

ほんとうはデータ型を定義してShowのインスタンスにすればカッコいいんだけど、やってないです。(^^;

まとめ

今回Scalaのコードは、割りと簡単に書けました。

運算の過程はかろうじて追えた気がしますが、結論が理解できませんでした。残念。><

いつものように完全なコードがGithubにおいてあります。

qtamaki/pearls · GitHub

*1:しかも、立て続けに二回出てきます。^_^;