セカイノカタチ

世界のカタチを探求するブログ。関数型言語に興味があり、HaskellやScalaを勉強中。最近はカメラの話題も多め

マーブルワーズ

「関数プログラミング 珠玉のアルゴリズムデザイン」をScalaで実装してみる 第2章

先日、 Scala Advent Calendar 2014 - Qiita で「関数プログラミング 珠玉のアルゴリズムデザイン」の第1章をScalaで実装してみましたが。味をしめて第2章もやってみたいと思います。

関数プログラミング 珠玉のアルゴリズムデザイン

関数プログラミング 珠玉のアルゴリズムデザイン

第2章の題材は「上位者問題」です。

この問題、思ったんですけど、CodeIQで結城 浩さんが出題した「クロッシング問題」と本質的に同じですね。

挑戦者求む!【アルゴリズム】交差点をすばやく数えよう! by The Essence of Programming 結城 浩│CodeIQ

この問題、自分は解けませんでした*1。(^^;

上位者問題

上位者問題とは、ある順序のある要素の集合(ここでは文字列です)を左から見ていった時に、右側に来る文字列の中に、自分自身より大きな要素が幾つあるか?という事を調べます。

そして、その数の一番大きい物を答える問題となります。

G E N E R A T I N G
5 6 2 5 1 4 0 1 0 0

例として、このようになるとされています。

一番左のGよりアルファベット順で後ろに来る文字は「N,R,T,I,N」の5つですね。

これを調べるための一番簡単なコードが下記のように例示されています。

-- msc (Max surpasser count)
msc :: Ord a => [a] -> Int
msc xs = maximum [scount z zs | z:zs <- tails xs]
scount x xs = length (filter (x<) xs)

tailsとは下記のような関数です。

tails [] = []
tails (x:xs) = (x:xs) : tails xs

Haskell標準のものとは違い、末尾が空リストになりません。(tails [1,2,3] -> [[1,2,3],[2,3],[3],[]]となるのが標準的)

つまり下記のように要素を外して行ったリストを作り、filterで自分自身以下の要素を排除し、長さを測っています。

G E N E R A T I N G
  E N E R A T I N G
    N E R A T I N G
      E R A T I N G

直感的なアルゴリズムをベタに実装しているだけですが、逆に言うと、思った通りに書けば、ちゃんと動くというあたりがHaskellの良い所です。

Scalaで実装してみる

実装をScalaに直します。

def msc[A : Ordering](xs: List[A]): Int = 
  xs.tails.collect({case (z::zs) => scount(z,zs)}).max

def scount[A : Ordering](x: A, xs: List[A]): Int = 
  xs.filter(y => implicitly[Ordering[A]].lt(x, y)).size

実装に際して、Scala独自のテクニックが使われています。

まず、「[A : Ordering]」と「implicitly」の部分ですが、これはHaskellの「Ord a =>」の部分に対応し、大小比較を抽象化した型クラスを扱うための機能です。

ギョーカイ用語で"implicit parameter"というやつで、「Ordering機能のあるAを引数に取る」という意味になります。

Haskellならば、「Ord a =>」と1ヶ所に書けばあとは適当に推論してくれるのですが、Scalaの場合、implicitlyしたい関数とそこへ至る関数に「[A : Ordering]」を書かなければならないところがちょっとアレです。(^^;

もう一つ「xs.tails.collect ...」の部分ですが、Haskellで独自のtailsを定義していたように、Scalaでも標準のtailsは、末尾に空リストがついてくるので、この部分が「case (z::zs)」のパタンマッチに合いません。

これを回避するために、mapではなく、collectを使っています。

少し、説明が長くなってしまいますが、「case (z::zs) => ...」の部分は"PartialFunction"という機能を使っています。

以下にPartialFunctionとググると必ず出てくるid:yuroyoro氏のページを貼ります。

ScalaのPartialFunctionが便利ですよ - ( ꒪⌓꒪) ゆるよろ日記

PartialFunctionとは、関数とcase文が融合したようなシロモノで、関数を呼び出さずとも「isDefinedAt」で引数に適合するかどうか調べることができます。

そのため、下記のように書くことで、List()がcaseにマッチしない場合に呼び出さない!という選択を行うことができます。

list.filter{pf.isDefinedAt}.map{pf}

ただ、これだと若干野暮ったいため、このセットをやってくれる標準関数が用意されています。

それが「collect」です!

今回Scala的に面白いのは以上で、後は淡々とアルゴリズムの説明になります。(^^;

分割統治

出ました!

今回も、分割統治で解決します。

分割統治が出てくるってことは、大体「O(n2)」を「O(n log n)」にしたい時です。今回もその時がやってきたようです。(^^;;

Haskell(っぽい)関数を見てみましょう*2

msc (xs ++ ys) = join (msc xs) (msc ys)

考え方としては、msc に渡されたリストを例えば半分に分割して、それぞれに対して小さいmscを実行することができたなら、それを再帰的に適用していけば、半分の半分の半分の半分の・・・。ということで、「n log 2」にすることができます。

このことを利用するのですが、一旦mscから離れます。

より一般的に、各要素(文字)と、その上位者数のペアを求める関数"table"の開発を目指すことにします。そのうち一番大きな上位者数が最大上位者数になるという目論見です。

tableのイメージ: [(t,0), (a,3), (m,0), (a,2), (k,0), (i,0)]

table関数の単純な定義です。(先ほどのscountを利用)

table xs = [(z, scount z zs) | z:zs <- tails xs]

この関数に分割統治の技法を適用したと仮定するとこんな感じになるはずです。

table (xs ++ ys) = join (table xs) (table ys)

そして、tableのxsとysを分割するためには、tails関数のイメージが下記の関数ようになる必要があります。

table (xs ++ ys) = map (++ys) (tails xs) ++ tails ys

これは、ぱっと見わかりにくいと思いますが、具体的には下記のようなリストを生成するイメージです。

SEKAINO KATACHI
xs      ys
SEKAINO ++ KATACHI
 EKAINO ++ KATACHI
  KAINO ++ KATACHI
   AINO ++ KATACHI
    INO ++ KATACHI
     NO ++ KATACHI
      O ++ KATACHI
++
           KATACHI
            ATACHI
             TACHI
              ACHI
               ...

SEKAINOKATACHIをxsとysに分割した場合、xsのtailsにはすべてysをくっつける必要があり、xsのtailsの後に、ysのtailsをくっつけると、tails (xs ++ ys)と同じ結果を得られるという事です。

そしてこれを元に「運算」します。

元の本では、ぐちゃぐちゃ書いてありますが、式の意味を変えずに形を変えていっています。目的は、joinを定義できるような形の発見です。

その結果、下記のような関数に行き着くということです。

join txs tys = [(z, c + tcount z tys) | (z,c) <- txs] ++ tys
tcount z tys = scount z (map fst tys)

運算中、「ys -> map fst (table ys)」という周りくどい変換を行っていますが、joinとtcountの引数がtxs,tys(xs,ysにtableを適用したもの(c,z)のペアの形式)であるため、tysからysに変換する必要があるのですが、運算中は、そんな事知る由もないのに未来を予知しているのがちょっとアレです。(^^;

しかも、ここまで頑張ったのに、joinがO(n2)なので、意味がありません(汗)。

ここからさらに、txs,tysを昇順に維持しながらすり合わせるように演算していくことで、O(n log 2)に持って行くことができます。

このへんまで読むと、「これってマージソートじゃん!!クロッシング問題じゃん!!!」となるわけです(知ってれば)。

ここから、式を変換していくわけですが、難しいので最終的な形を示します。(^^;

table [x] = [(x,0)]
talbe xs  = join (m-n) (table ys) (table zs)
            where m = length xs
                  n = m `div` 2
                  (ys,zs) = splitAt n xs

join 0 txs [] = txs
join n [] tys = tys
join n txs@((x,c) : txs') tys@((y,d) : tys')
    | x < y = (x,c+n) : join n txs' tys
    | x >= y = (y,d) : join (n-1) txs tys'

table -> tableの再帰とjoin -> joinの再帰が行われています。

tableの再帰は、リストをsplitAtによって、2分割しながら降下していき、xsの要素が1個になったところでjoinに引き渡され、joinによって、より小さい要素が前に来るようにマージされながら組み合わされていきます。

x<yの時に、マージしながら上位者の数だけカウントアップされます。つまり、yが大きいということは、ys全体がxより大きいことを示すため、nをlength ysとして、カウントするという寸法です。

Scalaでの実装

Scalaあんまり関係なくなってきた気もしますが、Scalaでの実装例を示します。

def table2[A : Ordering](xs: List[A]): List[(A, Int)] = xs match { 
  case List() => List.empty
  case List(x) => List((x, 0))
  case xs => 
    val m = xs.length
    val n = m / 2
    val (ys, zs) = xs.splitAt(n)
    join2(m - n, table2(ys), table2(zs))  
}

def join2[A : Ordering](n: Int, txs: List[(A, Int)], tys: List[(A, Int)]): List[(A, Int)] = (n, txs, tys) match {
  case (0, txs, Nil) => txs
  case (_, Nil, tys) => tys
  case (n, txs@((x,c) :: txs2), tys@((y,d) :: tys2)) =>
    if (implicitly[Ordering[A]].lt(x, y)) (x,c + n) :: join2(n, txs2, tys)
    else (y, d) :: join2(n-1, txs, tys2)
  case _ => ???
}

Orderingのテクニックはそのままです。

パタンマッチのために、match式を使っています。あと、ガードがif式になっているのと、パタンを網羅していないとワーニングが出るので、適当なパタンを補ってます。(^^;

アズパターンもそのまま使えます。

全体的にHaskellの方がノイズ成分が少ないような気がしますが、ほぼそのまま記述可能です。

*1:正確にはO(n2)のアルゴリズムでやって時間が足りなかった

*2:実際は、(xs ++ ys)というパタンマッチは不可能です。「例えばこんな感じ」という意味です