セカイノカタチ

世界のカタチを探求するブログ。関数型言語に興味があり、HaskellやScalaを勉強中。最近はカメラの話題も多め

「関数プログラミング 珠玉のアルゴリズムデザイン」をScalaで実装してみる 第5章

前回: 「関数プログラミング 珠玉のアルゴリズムデザイン」をScalaで実装してみる 第4章 - セカイノカタチ

関数プログラミング 珠玉のアルゴリズムデザイン

関数プログラミング 珠玉のアルゴリズムデザイン

この本のコードをScalaに置き換えながらじわじわ進むシリーズの第5弾です。

組和の整列

組和という単語でググっても何も出てきません。

・・・。

張り切って行ってみましょう!

今回の問題はこんなかんじです。

集合Aを「線形順序付き集合」とします。「線形順序付き集合」とは、小さい順に並べることができる集合。つまり、整数とか実数とかですね。

関数⊕(丸にプラス。面倒なので以下 + とします)があったとします。これは単調な二項演算子で、集合Aからx,y,x',y'の4つの数字を取り出した時に「x ≦ x'」で「y ≦ y'」ならば「x + y ≦ x' + y'」となることを条件に適当に定義します。

この関数を使いAのリストを2つ取って以下の様な計算を行う関数について考えます。

sortsum :: (Ord a, Num a) => [a] -> [a] -> [a]
sortsum xs ys = sort [x + y | x <- xs, y <- ys]

Scalaで書くとこんな感じです。

def sortsum(xs: List[Int], ys: List[Int]): List[Int] = 
  (for (x <- xs; y <- ys) yield x + y).sorted

xs × ys の全パターンに対して+を適用してその結果をソートすると読めます。

この時の比較回数は、最大 O(n2 log n) 回になると思います。

この条件のもとでは、これ以上の低減は出来ないとあります。

(A,+)がアーベル群であった時、比較回数を O(n2) まで減らせるそうです*1

ちなみにアーベル群とは、結合則を満たし単位元と逆元を持つ群に対して、交換則を加えるとアーベル群になります。

実装の準備

問題からして何やら難しいですが、気を取り直して実装の為の下準備をしていきます。

今回のポイントは、素直に + を計算するのではなく、逆数を求める negate と差演算を行う (以下 - )を使ってアルゴリズムを組み立てることです。

そのための前提として以下の事実が必要になります。

(5.1) x + y = x - negate y
(5.2) x - y < x' - y' ≡ x - x' < y - y'

(5.1)を検証するのはたやすいとありますが、 - negate が打ち消しあうので x + yになりますね。簡単ですね。

そして「(5.2)を検証するにはアーベル群の性質をすべて使わなくてはならない。これについては練習問題としておく」

・・・おい。(^^;

ここでいきなり練習問題が出されています。(^-^;;

焦りますが、慌てず検証してみます。

x - y < x' - y' ≡ x - x' < y - y'
-- 右側を何とかするために[- y + x']を加える
x - y < x' - y' ≡ x - x' - y + x' < y - y' - y + x'
-- 整理
x - y < x' - y' ≡ x - y < - y' + x'
-- 交換
x - y < x' - y' ≡ x - y < x' - y' 

無事、右辺と左辺が同じになりました。

これで準備完了です。

実装

実装に入ります。

差計算の結果に元のリストの何番目から取ったあたいかを示すラベルを付けて返す関数 subs を定義します。

type Label a = (a,(Int,Int))

subs :: (Num a) => [a] -> [a] -> [Label a]
subs xs ys = [(x - y, (i,j)) | (x,i) <- zip xs [1..],(y,j) <- zip ys [1..]]

Scalaで書くと。。

def subs(xs: List[Int], ys: List[Int]): List[Label] = {
  for {
    (i, x) <- Stream.from(1).zip(xs).toList
    (j, y) <- Stream.from(1).zip(ys).toList
  } yield (x - y, (i,j))
}

ラベルは、(x,y)の座標を表します。

この関数を使ってsortsumを再定義します。

sortsums xs ys = map fst (sortsubs xs (map negate ys))

sortsubs xs ys = sort (subs xs ys)

わざわざ x + y = x - negate y を使い、ラベル付き差演算をした結果からラベルを外しています。(^^;

def sortsums(xs: List[Int], ys: List[Int]) = sortsubs(xs, ys.map(_ * -1)).map(_._1)

def sortsubs(xs: List[Int], ys: List[Int]):List[Label] = subs(xs,ys).sorted

Scalaで書いてもあまり変わりませんね。

次にラベルに更にタグを張ってマージによって擦り合わせる謎の関数を定義します。

table :: (Ord a,Num a) => [a] -> [a] -> [(Int,Int,Int)]
table xs ys = map snd (map (tag 1) xxs `merge` map (tag 2) yys)
              where xxs = sortsubs xs xs
                    yys = sortsubs ys ys
tag i (x,(j,k)) = (x,(i,j,k))

merge :: (Ord a) => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge (x:xs) (y:ys) | x < y     = x : merge xs (y:ys)
                    | otherwise = y : merge (x:xs) ys

merge は、本では ^^ と重ねあわせた演算子となっていますが、無いので定義しています。

特に sortsubs xs xs sortsubs ys ys が謎いですね。(^^;

Scalaにしてみます。

type TaggedLabel = (Int,(Int,Int,Int))
def table(xs: List[Int], ys: List[Int]):List[(Int,Int,Int)] = {
  val xxs:List[TagedLabel] = sortsubs(xs,xs).map(tag(1,_))
  val yys:List[TagedLabel] = sortsubs(ys,ys).map(tag(2,_))
  merge(xxs, yys).map(_._2)
} 
def tag(i:Int, x: Label): TaggedLabel = x match {case (x,(j,k)) => (x,(i,j,k))}
def merge[A: Ordering](list1: List[A], list2: List[A]): List[A] = {
  (list1,list2) match {
    case (Nil, ys) => ys
    case (xs, Nil) => xs
    case (x::xs, y::ys) => 
      if (implicitly[Ordering[A]].lt(x,y)) x :: merge(xs, y::ys)
      else y :: merge(x::xs, ys)
  }
} 

sortsubs xs xs , sortsubs ys ys とは何か?

この章最大の謎の部分です。

sortsubs xs ysではなく、xs xsです。xsを2回渡しているのです。

「何かの間違いかな?」と思いましたが、違うようです。

わからなくて挫折しかかっていたのですが、下記のブログからヒントを頂き、理解することが出来ました。

【随時追記予定】読書メモ:関数プログラミング 珠玉のアルゴリズムデザイン - claustrophobia

この方凄いですね。^^;

自分なりに解説します。

まず、思い出して欲しいのは、この式です。

(5.2) x - y < x' - y' ≡ x - x' < y - y'

わざわざ検証したのに忘れてました。(^^;

これによると、任意のx,yとx',y'をマイナスしたものと、x,x'とy,y'をマイナスしたものの大小関係が合同と言っています。

これを利用しているわけです。

下の表を御覧ください。

f:id:qtamaki:20150312134240p:plain

これは、 xs=[1,2,3,4], ys=[5,11,6,8] とした時の組み合わせ表です。

上段の2つの表から、任意のx,yとx',y'を抜き出しています。

x :2 - y :6 (i:2,j:3)
x':4 - y':5 (k:4,l:1)

x,y,x',y'の関係はこんな感じに直交します。

この時、図をxとyで縦に読むように対応させると、前提「(5,2)」に従って下記が言えるということです。

x:2 - x':4 < y:6 - y':5
    -2     <     1

i,jの表とk,lの表は、それぞれ sortsubs に[1,2,3,4]と[5,11,6,8]を渡した時に発生するインデックス毎の計算結果の表です。

大小関係が保持されていることがわかります。

本の table の処理では、(x-x')の表に1(Tag1)、(y-y')の表に2(Tag2)のタグを付けた上で、計算結果が小さい順にソートしマージしています。

そのためリストの前の方に来る値の方が小さくなるわけですが(当然ですね)、Tag1由来の要素がTag2由来の要素より先に来るということは、(x-x')<(y-y')なので、(x-y)<(x'-y')が言えるわけです。

このことを利用してソートするというのが、この章のミソなわけです。

そのための下準備として、(x-x'),(y-y')のインデックスを作成するのが、 sortsubs xs xs , sortsubs ys ys の正体ということになります。

大変難しいですね。(^^;

分割統治

さて、前半が終わり(!!)後半戦、いつもの様に分割統治が始まるわけですが、今回はメインではないっぽいので端折ります。(ぉ

sortsubsが、そのままだと使えないのでsortsubs'を定義します。

原理としては、以下の恒等式を利用してリストを分割しているらしいです。

(xs ++ ys) - (xs' ++ ys')
    = (xs - xs') ++ (xs - ys') ++ (ys ++ xs') ++ (ys ++ ys')

コードを示します。

import Data.List
import Data.Array

type Label a = (a,(Int,Int))


sortsums :: (Ord a, Num a) => [a] -> [a] -> [a]
sortsums xs ys = map fst (sortsubs xs (map negate ys))

sortsubs :: (Ord a, Num a) => [a] -> [a] -> [Label a]
sortsubs xs ys = sortBy (cmp (mkArray xs ys)) (subs xs ys)

subs :: (Ord a, Num a) => [a] -> [a] -> [Label a]
subs xs ys = [(x - y, (i,j))|(x,i) <- zip xs [1..], (y,j) <- zip ys [1..]]
cmp a (x,(i,j)) (y,(k,l)) = compare (a ! (1,i,k)) (a ! (2,j,l))
mkArray xs ys = array b (zip (table xs ys) [1..])
                where b = ((1,1,1),(2,p,p))
                      p = max (length xs) (length ys)
table xs ys   = map snd (map (tag 1) xxs `merge` map (tag 2) yys)
                where xxs = sortsubs' xs
                      yys = sortsubs' ys
tag i (x,(j,k)) = (x,(i,j,k))

sortsubs' :: (Ord a, Num a) => [a] -> [Label a]
sortsubs' [] = []
sortsubs' [w] = [(w - w,(1,1))]
sortsubs' ws = foldr1 merge [xxs, map (incr m) xys,
                             map (incl m) yxs, map (incb m) yys]
  where xxs = sortsubs' xs
        xys = sortBy (cmp (mkArray xs ys)) (subs xs ys)
        yxs = map switch (reverse xys)
        yys = sortsubs' ys
        (xs,ys) = splitAt m ws
        m = length ws `div` 2

incl m (x,(i,j)) = (x,(m+i,j))
incr m (x,(i,j)) = (x,(i,m+j))
incb m (x,(i,j)) = (x,(m+i,m+j))
switch (x,(i,j)) = (negate x,(j,i))

merge :: (Ord a) => [a] -> [a] -> [a]
merge [] ys = ys
merge xs [] = xs
merge (x:xs) (y:ys) | x < y     = x : merge xs (y:ys)
                    | otherwise = y : merge (x:xs) ys

さらにScalaのコードです。

type Label = (Int,(Int,Int))
type TaggedLabel = (Int,(Int,Int,Int))
type LabeledArray = Map[(Int,Int,Int), Int]

def sortsums(xs: List[Int], ys: List[Int]) = sortsubs(xs, ys.map(_ * -1)).map(_._1)

def sortsubs(xs: List[Int], ys: List[Int]):List[Label] = subs(xs,ys).sortWith(cmp(mkArray(xs,ys)))

def subs(xs: List[Int], ys: List[Int]): List[Label] = {
  for {
    (i, x) <- Stream.from(1).zip(xs).toList
    (j, y) <- Stream.from(1).zip(ys).toList
  } yield (x - y, (i,j))
}

def cmp(a:LabeledArray)(l1:Label, l2:Label): Boolean = (l1,l2) match {
  case ((x,(i,j)),(y,(k,l))) => a((1,i,k)) - a((2,j,l)) < 0
}

def mkArray(xs:List[Int], ys:List[Int]): LabeledArray = {
  table(xs,ys).zip(Stream.from(1)).toMap
}

def table(xs: List[Int], ys: List[Int]):List[(Int,Int,Int)] = {
  val xxs:List[TaggedLabel] = sortsubs2(xs).map(tag(1,_))
  val yys:List[TaggedLabel] = sortsubs2(ys).map(tag(2,_))
  merge(xxs, yys).map(_._2)
} 

def merge[A: Ordering](list1: List[A], list2: List[A]): List[A] = {
  (list1,list2) match {
    case (Nil, ys) => ys
    case (xs, Nil) => xs
    case (x::xs, y::ys) => 
      if (implicitly[Ordering[A]].lt(x,y)) x :: merge(xs, y::ys)
      else y :: merge(x::xs, ys)
  }
} 

def tag(i:Int, x: Label): TaggedLabel = x match {case (x,(j,k)) => (x,(i,j,k))}

def switch:PartialFunction[Label,Label] = {case (x:Int,(i,j)) => (x * -1, (j,i))}

def sortsubs2(ws: List[Int]):List[Label] = {
  lazy val m = ws.length / 2
  lazy val (xs,ys) = ws.splitAt(m)
  lazy val yys = sortsubs2(ys)
  lazy val xys = subs(xs,ys).sortWith(cmp(mkArray(xs,ys)))
  lazy val yxs = xys.reverse.map(switch)
  lazy val xxs = sortsubs2(xs)
  ws match {
    case Nil => Nil
    case w::Nil => List((w-w,(1,1)))
    case ws => val list = List(xys.map(incr(m, _)), yxs.map(incl(m, _)), yys.map(incb(m, _)))
               list.foldRight(xxs)(merge)
  }
}

def incl(m:Int, a: Label): Label = {
  val (x, (i,j)) = a
  (x,(m+i,j))
}

def incr(m:Int, a: Label): Label = {
  val (x, (i,j)) = a
  (x,(i,m+j))
}

def incb(m:Int, a: Label): Label = {
  val (x, (i,j)) = a
  (x,(m+i,m+j))
}

def switch(a: Label): Label = {
  val (x, (i,j)) = a
  (x * -1,(i,j))
}

まとめ

実は、今回Scalaのコードもかなり難しかったです。

引数のパタンマッチが使えると良いのになー。と思いました。(^^;

いつものように完全なコードがGithubにおいてあります。

qtamaki/pearls · GitHub

*1:しかし、追加で Cn2 log n 回の演算が必要とのことでなんのこっちゃです。(^^;