#CP AVL木をパクった話 Python

友人のAVL木のソースコードをパクってAVL木の高速化を目指しました(ダメでした)

きっかけ

pythonには標準ライブラリで平衡二分探索木を作れるものがありません。
競技プログラミングでは平衡二分探索木を使うことも多く、そのため友人が自力で作っていたのですがどうもデータ数に比べて処理が遅いです。そのAVL木のソースコードは以下のページにあります。
juppy.hatenablog.com

計算量はO(NlogN)のはずなので10^5くらいなら余裕のはずですが、pythonでは10^5のデータをinsertするだけで0.6秒ほどかかります。そのため、searchなどをするAtCoderの2秒制限は軽くオーバーします。なんとか高速化できないものかと彼のAVL木の表現方法を、ノード1個をクラスとして定める方式から、木全体を1個の入れ子にしたリストで表現してみました。リストの方がアクセス高速じゃね?とか思ったわけです。結果からいうと、あんまり変わりませんでした

ソースコード

class Avltree:
    """ Avl木を表現するクラス """
    def __init__(self):
        self.avl = [None,None,None,0]

    def node(self,key):
        return [None,key,None,0]

    # search(x) xが存在するならTrueをしないならFalseを返す
    # 使用法 Instance.search(x)
    def search(self,key,avl=None):
        if avl == None:
            avl = self.avl
        if key == avl[1]:
            return True
        elif key < avl[1]:
            if avl[0] == None:
                return False
            else:
                return self.search(key,avl[0])
        elif avl[1] < key:
            if avl[2] == None:
                return False
            else:
                return self.search(key,avl[2])
    
    # key未満で最大の要素を検索する。なければNoneを返す。
    # 使用法 Instance.search_lower(x)
    def search_lower(self,key,key_lower=None,avl=0):
        if avl == 0:
            avl = self.avl
        if avl == None:
            return key_lower
        elif key < avl[1]:
            if avl[0] == None:
                return key_lower
            else:
                return self.search_lower(key,key_lower,avl[0])
        elif avl[1] < key:
            key_lower = avl[1]
            if avl[2] == None:
                return key_lower
            else:
                return self.search_lower(key,key_lower,avl[2])
        #avl[1] == keyの場合
        if avl[0] == None:
            return key_lower
        else:
            if key_lower == None:
                return self.end_higher(avl[0][1],avl[0])
            else:
                return max(key_lower,self.end_higher(avl[0][1],avl[0]))
    
    # keyより大きい最小の要素を検索する。なければNoneを返す。
    # 使用法 Instance.search_higher(x)
    def search_higher(self,key,key_higher=None,avl=0):
        if avl == 0:
            avl = self.avl
        if avl == None:
            return key_higher
        if  key < avl[1]:
            key_higher = avl[1]
            if avl[0] == None:
                return key_higher
            else:
                return self.search_higher(key,key_higher,avl[0])
        if avl[1] < key:
            if avl[2] == None:
                return key_higher
            else:
                return self.search_higher(key,key_higher,avl[2])
        #self.key == keyの場合
        if avl[2] == None:
            return key_higher
        else:
            if key_higher == None:
                return self.end_lower(avl[2][1],avl[2])
            else:
                return min(key_higher,self.end_lower(avl[2][1],avl[2]))
    
    def end_lower(self,end_lower_key,avl=None):
        if avl == None:
            avl = self.avl
        if avl[0] == None:
            return end_lower_key
        else:
            end_lower_key = avl[0][1]
            return self.end_lower(end_lower_key,avl[0])
    def end_higher(self,end_higher_key,avl=None):
        if avl == None:
            avl = self.avl
        if avl[2] == None:
            return end_higher_key
        else:
            end_higher_key = avl[2][1]
            return self.end_higher(end_higher_key,avl[2])

    def DoubleRightRotation(self,avl):
        tl = avl[0]
        avl[0] = tl[2][2]
        tl[2][2] = avl # selfはそのノード
        tlr = tl[2]
        tl[2] = tlr[0]
        tlr[0] = tl
        if tlr[3] == 1:
            tlr[2][3] = 2
            tlr[0][3] = 0
        elif tlr[3] == 2:
            tlr[2][3] = 0
            tlr[0][3] = 1
        elif tlr[3] == 0:
            tlr[2][3] = 0
            tlr[0][3] = 0
        tlr[3] = 0
        return tlr

    def DoubleLeftRotation(self,avl):
        tr = avl[2]
        avl[2] = tr[0][0]
        tr[0][0] = avl
        trl = tr[0]
        tr[0] = trl[2]
        trl[2] = tr
        if trl[3] == 2:
            trl[0][3] = 1
            trl[2][3] = 0
        elif trl.balance == 1:
            trl[0][3] = 0
            trl[2][3] = 2
        elif trl.balance == 0:
            trl[0][3] = 0
            trl[2][3] = 0
        trl[3] = 0
        return trl

    def SingleLeftRotation(self,avl):
        tr = avl[2]
        tr[3] = 0
        avl[3] = 0
        avl[2] = tr[0]
        tr[0] = avl
        return tr

    def SingleRightRotation(self,avl):
        tl = avl[0]
        tl[3] = 0
        avl[3] = 0
        avl[0] = tl[2]
        tl[2] = avl
        return tl

    def replace(self,p,v,avl): # 親ノードpの下にある自分(avl)をvに置き換える。
        if p[0] is avl:
            p[0] = v
        else :
            p[2] = v
    
    # 木に要素を追加する。
    # 使用法 Instance.insert(x)
    def insert(self,key): # rootでのみ呼ばれる挿入
        if self.avl[1] == None: # rootを含むrotationはしないことにする。
            self.avl[1] = key
            return self.avl
        if key < self.avl[1]:
            if self.avl[0] == None:
                self.avl[0] = self.node(key)
            else:
                self.insertx(self.avl,key,self.avl[0])
        elif self.avl[1] < key:
            if self.avl[2] == None:
                self.avl[2] = self.node(key)
            else:
                self.insertx(self.avl,key,self.avl[2])
        else: # key == self.avl[1]:
            pass # do not overwrite
    def insertx(self,p,key,avl): # replaceを呼ぶために一つ上の親を持っているinsert
        if key < avl[1]:
            if avl[0] == None:
                avl[0] = self.node(key)
            else:
                if not self.insertx(avl,key,avl[0]): # 左の木が生長しなければ、
                    return False # 成長しない
            balance = avl[3]
            if balance == 2:
                avl[3] = 0
                return False
            elif balance == 0:
                avl[3] = 1
                return True # 成長した
            elif balance == 1:
                if avl[0][3] == 2:
                    self.replace(p,self.DoubleRightRotation(avl),avl)
                elif self.left.balance == 1:
                    self.replace(p,self.SingleRightRotation(avl),avl)
                return False # rotationを行うと成長しない
        if avl[1] < key:
            if avl[2] == None:
                avl[2] = self.node(key)
            else:
                if not self.insertx(avl,key,avl[2]):
                    return False
            balance = avl[3]
            if balance == 1:
                avl[3] = 0
                return False
            elif balance == 0:
                avl[3] = 2
                return True
            elif balance == 2:
                if avl[2][3] == 1:
                    self.replace(p,self.DoubleLeftRotation(avl),avl)
                elif avl[2][3] == 2:
                    self.replace(p,self.SingleLeftRotation(avl),avl)
                return False
        return False # avl[1] == keyの時は何もしないので成長もしない

    def debug(self):
        print(self.avl)

具体的な使用例

長いので定義の部分は省略します。

### ここまで定義 ###
def main():
    at = Avltree()
    for e in [2,3,4,7,8,9,11,13,16,18,19,20]:
        at.insert(e)
    for e in [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]:
        print(e,at.search(e))
    for e in [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]:
        print(e,at.search_lower(e,None),at.search_higher(e,None))
    
if __name__ == "__main__":
    main()

実行結果(コメントは付け足したものです)

1 False # 1~21までが存在するかどうか判定
2 True
3 True
4 True
5 False
6 False
7 True
8 True
9 True
10 False
11 True
12 False
13 True
14 False
15 False
16 True
17 False
18 True
19 True
20 True
21 False
1 None 2 # 1~21までの数字に対し、小さくて最大の要素
2 None 3 # 大きくて最小の要素を出力 なければNone
3 2 4
4 3 7
5 4 7
6 4 7
7 4 8
8 7 9
9 8 11
10 9 11
11 9 13
12 11 13
13 11 16
14 13 16
15 13 16
16 13 18
17 16 18
18 16 19
19 18 20
20 19 None
21 20 None

実行速度結果比較

自分の作ったAVL木
以下のmain関数を10回回しました。

def main():
    start = time.time()
    treeA = Avltree()
    for i in range(0,10**5):
        treeA.insert(i)
    for i in range(1,10**5):
        treeA.search_higher(i)
        treeA.search_lower(i)
    print(time.time()-start)

f:id:harutech:20190226221800p:plain

じゅっぴー氏の作ったAVL木
同様に以下のmain関数を10回回しました。

def main():
	start = time.time()
	treeA = Avltree()
	for i in range(0,10**5):
		treeA.insert(i)
	for i in range(1,10**5):
		treeA.search_higher(i,None)
		treeA.search_lower(i,None)
	print(time.time()-start)

f:id:harutech:20190226222414p:plain

結論

マジで変化はありませんでした。Pythonに平衡二分探索木は無理です。C++を勉強したい。