본문 바로가기
알고리즘 설명

[PS를 위한 자료구조 3강] 세그먼트 트리의 구현(재귀) 2

by 승욱은 2022. 3. 24.
 

PS를 위한 자료구조 1-3강

# 세그먼트 트리의 구현(재귀) 최적화 # 에 대해 알아보겠습니다.

지난 시간까지 세그먼트 트리에 대한 대략적인 이해를 모두 마쳤을 것이라 생각합니다.

이제 세그먼트 트리를 최적화해봅시다.

앞에서 알아본 노드의 설계도를 가져오겠습니다.

Class 노드:
    int 구간의 합
    int 왼쪽, 오른쪽 # 구간의 왼쪽, 오른쪽 경계
    노드 LEFT, RIGHT # 노드는 자료형입니다.

노드의 설계도에서 무려 앞으로 필요없는 값이 있습니다.

바로 구간의 왼쪽, 오른쪽 경계입니다..!!

이게 왜 필요없는지 분명 의아하실 겁니다.

그 이유에 대해 설명하겠습니다.

1. 원소는 업데이트만 될 뿐 추가나 삭제되지 않는다.

-> 원소가 추가되는 경우는 세그먼트 트리 자료구조를 이용할 수 없습니다. 이 경우는 다른 자료구조를 이용해야합니다.

-> 따라서 루트가 담당하는 구간은 항상 일정합니다.

2. 루트에서 시작해서 재귀로 내려간다.

-> 우리가 접근하고자 하는 노드는 항상 루트에서 시작해서 재귀적으로 내려갑니다.

-> 루트의 구간을 절반씩 나누며 매번 재귀하여 들어갈 때마다 구간을 구하면 되는 겁니다.

-> 즉, 구간에 대한 정보를 저장할 필요 없이, 루트의 범위만 알고 있으면 매번 구할 수 있다는 것입니다.

 

 

그렇다면 여기서 질문이 있을 수 있습니다.

선생님. 미리 값을 가지고 있는게 낫지, 매번 구하며 들어가면 손해 아닌가요?

반은 맞고 반은 틀립니다.

1. 미리 값을 저장하는 것이, 아주 매우 미세하게 속도가 빠른 것으로 보이나 큰 차이가 없습니다.

2. 그런데 구간의 값을 가지는 메모리는? 사용하는 노드의 갯수만큼이나 메모리 사용이 커질 것입니다. 그리고 보통 노드의 양이 수십만 개이므로 메모리 사용이 훨씬 클 것을 알 수 있습니다.

'아주 약간의 시간을 위해, 너무 큰 메모리를 사용할 이유는 없다.' 라고 볼 수 있습니다.

이제 설계도에서 구간을 없애주겠습니다.

 

Class 노드:
    int 구간의 합
    노드 LEFT, RIGHT # 노드는 자료형입니다.
 

 

매번 함수의 매개변수로써 구간의 범위에 대한 정보를 받을 수 있도록 함수를 약간씩 수정하겠습니다.

def 세그먼트 트리 전처리(구간의 왼쪽 : int = 1, 구간의 오른쪽 : int = N) -> 노드:
    # 루트의 구간은 항상 정해져 있으므로 기본값으로 1과 N을 넣어줍시다.
    # 해당하는 구간의 노드를 만들고, 해당 노드를 리턴합니다.
    NEW노드 = 노드()
    # NEW노드.왼쪽 = 구간의 왼쪽 #### 삭제
    # NEW노드.오른쪽 = 구간의 오른쪽 #### 삭제
    # 새로운 노드의 구간을 업데이트
    if 구간의 왼쪽 == 구간의 오른쪽:
        # 구간에 포함된 값이 1개라면, 해당값을 등록하고 더 쪼갤 수 없으므로 재귀 종료
        NEW노드.구간의 합 = 해당값
        return NEW노드
    # 중간을 기준으로 구간 쪼개기를 진행.
    중간 = (구간의 왼쪽 + 구간의 오른쪽)//2
    NEW노드.LEFT = 세그먼트 트리 전처리(구간의 왼쪽, 중간)
    # 왼쪽 자식을 재귀로 생성
    NEW노드.RIGHT = 세그먼트 트리 전처리(중간+1, 구간의 오른쪽)
    # 오른쪽 자식을 재귀로 생성
    NEW노드.구간의 합 = NEW노드.LEFT.구간의 합 + NEW노드.RIGHT.구간의합
    # 구간의 합은 두 자식의 구간의 합의 합
    return NEW노드

루트의 구간의 값은 항상 같으므로, 루트의 구간의 값을 기본값으로 넣어주었습니다.

업데이트 함수도 바꿔줍니다.

def 업데이트(노드 : 노드, 번호 : int, 추가값 : int, 구간의 왼쪽 : int = 1, 구간의 오른쪽 : int = N):
    # 해당 노드를 업데이트 하며, 리턴값은 없음
    if 구간의 왼쪽 == 구간의 오른쪽:
        # 마지막 계층 도달. 추가해주고 종료
        노드.구간의 합 += 추가값
        return
    중간 = (구간의 왼쪽 + 구간의 오른쪽)//2
    if 번호 <= 중간:
        업데이트(노드.LEFT, 번호, 추가값, 구간의 왼쪽, 중간)
    else:
        업데이트(노드.RIGHT, 번호, 추가값, 중간+1, 구간의 오른쪽)
    # 자식들의 업데이트가 모두 끝났으므로 다시 자신의 값을 구함
    노드.구간의 합 = 노드.LEFT.구간의 합 + 노드.RIGHT.구간의 합

쿼리 함수도 바꿔줍니다.

def 쿼리(노드 : 노드, L : int, R : int, 구간의 왼쪽 : int = 1, 구간의 오른쪽 : int = N) -> int :
    if 구간의 오른쪽 < L OR R < 구간의 왼쪽:
        # 완전히 포함되지 않는 노드는 답에 영향을 주지 않기 위해 0을 리턴
        return 0
    if L <= 구간의 왼쪽 AND 구간의 오른쪽 <= R:
        # 완전히 포함되는 노드는 자신의 값을 리턴
        return 노드.구간의 합
    중간 = (구간의 왼쪽 + 구간의 오른쪽)//2
    # 위 두 조건에 걸리지 않았다면 애매하게 걸쳐 있는 노드, 자식들에게 맡기자.
    # 왼쪽 자식이 건져오는 값과 오른쪽 자식이 건져오는 값을 합함
    # 만약 포함되는게 없다면 0으로 리턴하여 아무런 영향이 없을것
    return 쿼리(노드.LEFT, L, R, 구간의 왼쪽, 중간) + 쿼리(노드.RIGHT, L, R, 중간+1, 구간의 오른쪽)

여기서 아주 약간의 최적화를 더할 수 있는데, 굉장히 뜬금 없는 부분입니다.

바로 구간의 중간을 구하는 부분입니다.

중간 = (구간의 왼쪽 + 구간의 오른쪽)//2

이걸 최적화한다고? 싶으실텐데, 여기서 비트연산을 이용해줄 것입니다.

중간 = 구간의 왼쪽 + 구간의 오른쪽 >> 1

>> 연산자 비트를 오른쪽으로 미뤄주는데, 여기서 1칸을 미뤄주면 2로 나눠주는 것과 같습니다.

연산자 우선순위가 +가 더 높기 때문에 위 두 식은 결과가 완전히 똑같습니다.

비트 연산자는 매우 빠릅니다.

비트 연산이 가능하다면 무조건 비트연산자를 쓰는 것이 시간적으로 이득입니다.

여기까지 세그먼트 트리를 코드를 첨부합니다.

1. 파이썬

arr = [*range(1, 257)]  # 1부터 256


class Node:
    def __init__(self) -> None:
        self.sum = 0                    # 구간의 합
        self.left = self.right = None   # 왼쪽, 오른쪽 자식


class Segtree:
    def __init__(self) -> None:
        # 세그먼트 트리 [1, 256] 범위 생성자
        self.root = self.init()
        # 트리의 루트 생성

    def init(self, start: int = 1, end: int = 256) -> Node:
        # 노드 생성
        node = Node()
        if start == end:
            # 마지막 계층 도달. 노드의 값 등록
            # 현재 세그먼트 트리는 1-based index이고
            # arr는 0-based index이므로 start에 1을 빼줌
            node.sum = arr[start-1]
            return node
        mid = start + end >> 1
        # 왼쪽, 오른쪽 자식들을 생성하고
        node.left = self.init(start, mid)
        node.right = self.init(mid+1, end)
        # 완성 이후 구간의 합을 줍줍
        node.sum = node.left.sum + node.right.sum
        return node

    def update(self, node: Node, idx: int, plus: int, start: int = 1, end: int = 256) -> None:
        # 노드를 업데이트
        # idx : 업데이트 하고자 하는 값
        # plus : 더하고자 하는 값
        if start == end:
            # 마지막 계층 도달. 노드의 값 추가
            node.sum += plus
            return

        mid = start + end >> 1
        # idx가 mid보다 작거나 같으면 왼쪽으로 아니면 오른쪽으로
        if idx <= mid:
            self.update(node.left, idx, plus, start, mid)
        else:
            self.update(node.right, idx, plus, mid+1, end)
        # 자식들 업데이트 완료. 구간합 줍줍
        node.sum = node.left.sum + node.right.sum

    def query(self, node: Node, l: int, r: int, start: int = 1, end: int = 256) -> int:
        # 구간합 구하기
        # l : 구간의 왼쪽 값, r : 구간의 오른쪽 값
        if end < l or r < start:
            return 0
        if l <= start and end <= r:
            return node.sum
        mid = start + end >> 1
        return self.query(node.left, l, r, start, mid) + self.query(node.right, l, r, mid+1, end)


tree = Segtree()

print(tree.query(tree.root, 2, 5))          # 2+3+4+5
tree.update(tree.root, 3, 2)                # 3 -> 5
print(tree.query(tree.root, 2, 5))          # 2+5+4+5
 

2. c++

#include <bits/stdc++.h>

using ll = long long;
using namespace std;

struct Node
{
    Node *left, *right;
    ll sum;
    Node() {
        left = right = nullptr;
        sum = 0;
    }
    ~Node(){
        if (left) delete left;
        if (right) delete right;
    }
};

int N; // 기본 값 설정을 위한 전역변수

struct Segtree
{
    Node *tree;
    vector<ll> arr;
    Segtree(vector<ll> &tmp){
        arr = tmp;
        N = tmp.size();
        tree = init();
    }
    ~Segtree(){
        if (tree) delete tree;
    }

    Node* init(int start=1, int end=N){
        Node *node = new Node();
        if (start == end){
            node->sum = arr[start-1];
            return node;
        }
        int mid = start + end >> 1;
        node->left = init(start, mid);
        node->right = init(mid+1, end);
        node->sum = node->left->sum + node->right->sum;
        return node;
    }

    void update(Node *node, int idx, int plus, int start=1, int end=N){
        if (start == end){
            node->sum += plus;
            return;
        }
        int mid = start + end >> 1;
        if (idx <= mid) update(node->left, idx, plus, start, mid);
        else update(node->right, idx, plus, mid+1, end);
        node->sum = node->left->sum + node->right->sum;
    }

    ll query(Node *node, int l, int r, int start=1, int end=N){
        if (end < l || r < start) return 0;
        if (l <= start && end <= r) return node->sum;
        int mid = start + end >> 1;
        return query(node->left, l, r, start, mid) + query(node->right, l, r, mid+1, end);
    }
};

이렇게 하면 최적화가 무려 끝이 아닙니다(?)

사실 이제부터 시작입니다.

 

여기부터 실제 PS러들이 실전에서 자주 사용하는 전문적인(?) 세그먼트 트리라고 볼 수 있습니다.

이를 설명하기 이전에 여러분들은 "완전 이진 트리" 에 대해 이해하셔야 합니다.

완전 이진트리란, 트리의 가장 말단 노드(리프 노드)와 그 부모를 제외하고는

모두 자식이 2개씩 채워져 있는 트리를 의미합니다.

그리고 리프 노드는 왼쪽부터 채워져 있어야 합니다.

 

세그먼트 트리는 완전 이진 트리입니다.

세그먼트 트리를 전처리하는 과정에서 리프 노드가 아니면 항상 2개의 자식 노드를 만들어왔기 때문이죠.

완전 이진 트리는 다음과 같이 번호를 붙힐 수 있습니다.

여기서 볼 수 있는 특징을 살펴봅시다..!

1. 자식의 왼쪽 노드는 자기 자신의 번호 *2, 오른쪽 노드는 자기 자신의 번호*2+1 입니다.

-> 비트 연산을 이용해줍시다. 앞에서 말씀 드렸듯이 비트 연산은 빠릅니다.

왼쪽 자식 = 자신<<1
오른쪽 자식 = 자신<<1|1
# 비트를 왼쪽으로 밀어주는 것은 *2를 하는 것과 같음
#  |1 을 해주는 것은 비어져있는 첫번째 비트를 채워준 것. 즉 +1과 역할이 유사.
 

2. 부모의 노드는 자신의 번호//2입니다. (현재 구현에서는 이용하지 않겠습니다.)

이제 무엇을 할거냐..! 객체의 설계도에서 자식 노드의 주소를 삭제할 겁니다.

내 자신의 번호만 알고 있다면 단순한 곱셈, 덧셈으로 자식 노드를 찾아갈 수 있기 때문이죠.

Class 노드:
    int 구간의 합
 

뭔가 허전하네요.. ㅎㅎ 구간의 합만이 남았습니다.

그럼 이제 눈치 빠른 분들은 눈치 채셨겠죠.

이제 객체를 이용할 필요가 없습니다.

그냥 int 배열을 이용하면 됩니다.

int 트리[필요한 노드의 갯수]
# (c++, java는 필요에 따라서 long long)

 

 

단, 여기서 필요한 노드의 갯수는 몇 개일까요?

쉽게 생각하면, N보다 큰 가장 작은 2의 제곱수의 2배 정도입니다.

위의 완전 이진 트리에서 하나의 계층을 채우는 가장 오른쪽 원소의 번호가 2^N-1 꼴임을 볼 수 있기 때문이죠.

하지만 N보다 큰 가장 작은 2의 제곱수의 2배를 구해도 되지만

귀찮다면 메모리를 조금 낭비하고 N의 4배를 해도 괜찮습니다.

아니 그러면 지금까지 객체는 왜 이용한거야?? 이유는 다음과 같습니다.

1. 이해하기 가장 쉬운 방법입니다.

2. 나중에는 또 필요해집니다.

(1) 게으른 세그먼트 트리 구현

(2) 금광 세그 등의 테크닉을 위한 구현

(3) 다이나믹 세그먼트 트리 구현

자 그러면 이제 세그먼트 트리를 다시 구현해봅시다.

위 그림에서 루트는 항상 "1" 이라는 것을 꼭 기억해둡시다. 

def 세그먼트 트리 전처리(구간의 왼쪽 : int = 1, 구간의 오른쪽 : int = N, 노드 : int = 1) -> None:
    #처음 노드는 항상 루트 이므로 기본값에 1을 넣어줄 수 있도록 합니다.
    if 구간의 왼쪽 == 구간의 오른쪽:
        # 구간에 포함된 값이 1개라면, 해당값을 등록하고 더 쪼갤 수 없으므로 재귀 종료
        트리[노드] = 해당값
        return NEW노드
    # 중간을 기준으로 구간 쪼개기를 진행.
    중간 = 구간의 왼쪽 + 구간의 오른쪽 >> 1
    세그먼트 트리 전처리(구간의 왼쪽, 중간, 노드<<1)
    # 왼쪽 자식에 접근
    세그먼트 트리 전처리(중간+1, 구간의 오른쪽, 노드<<1|1)
    # 오른쪽 자식에 접근
    트리[노드] = 트리[노드<<1] + 트리[노드<<1|1]
    # 구간의 합은 두 자식의 구간의 합의 합

업데이트 함수도 바꿔줍니다.

def 업데이트(번호 : int, 추가값 : int, 구간의 왼쪽 : int = 1, 구간의 오른쪽 : int = N, 노드 : int = 1):
    # 해당 노드를 업데이트 하며, 리턴값은 없음
    if 구간의 왼쪽 == 구간의 오른쪽:
        # 마지막 계층 도달. 추가해주고 ㅌㅌ
        트리[노드] += 추가값
        return
    중간 = 구간의 왼쪽 + 구간의 오른쪽 >> 1
    if 번호 <= 중간:
        업데이트(번호, 추가값, 구간의 왼쪽, 중간, 노드<<1)
    else:
        업데이트(번호, 추가값, 중간+1, 구간의 오른쪽, 노드<<1|1)
    # 자식들의 업데이트가 모두 끝났으므로 다시 자신의 값을 구함
    트리[노드] = 트리[노드<<1] + 트리[노드<<1|1]

쿼리 함수도 바꿔줍니다.

def 쿼리(L : int, R : int, 구간의 왼쪽 : int = 1, 구간의 오른쪽 : int = N, 노드 : int = 1) -> int :
    if 구간의 오른쪽 < L OR R < 구간의 왼쪽:
        # 완전히 포함되지 않는 노드는 답에 영향을 주지 않기 위해 0을 리턴
        return 0
    if L <= 구간의 왼쪽 AND 구간의 오른쪽 <= R:
        # 완전히 포함되는 노드는 자신의 값을 리턴
        return 트리[노드]
    중간 = 구간의 왼쪽 + 구간의 오른쪽 >> 1
    # 위 두 조건에 걸리지 않았다면 애매하게 걸쳐 있는 노드, 자식들에게 맡기자.
    # 왼쪽 자식이 건져오는 값과 오른쪽 자식이 건져오는 값을 합함
    # 만약 포함되는게 없다면 0으로 리턴하여 아무런 영향이 없을것
    return 쿼리(노드.LEFT, L, R, 구간의 왼쪽, 중간, x<<1) + 쿼리(노드.RIGHT, L, R, 중간+1, 구간의 오른쪽, x<<1|1)
 

기본적으로 배열에 접근하는 것이 훨씬 빠르고 구조적으로 간단하기 때문에,

객체로 접근하는 시간을 줄여 시간을 크게 단축할 수 있고, 무엇보다 호출할 때 사용이 훨씬 깔끔해집니다.

해당 코드를 첨부합니다.

1. 파이썬

n = (1 << 17) + 2050
tree = [0]*(1 << 19)


def build(arr: list, start: int = 1, end: int = n, node: int = 1) -> None:
    if start == end:
        tree[node] = arr[start-1]
        return
    mid = start + end >> 1
    build(arr, start, mid, node << 1)
    build(arr, mid+1, end, node << 1 | 1)
    tree[node] = tree[node << 1] + tree[node << 1 | 1]


def update(idx: int, plus: int, start: int = 1, end: int = n, node: int = 1) -> None:
    if start == end:
        tree[node] += plus
        return
    mid = start + end >> 1
    if idx <= mid:
        update(idx, plus, start, mid, node << 1)
    else:
        update(idx, plus, mid+1, end, node << 1 | 1)
    tree[node] = tree[node << 1] + tree[node << 1 | 1]


def query(l: int, r: int, start: int = 1, end: int = n, node: int = 1) -> int:
    if end < l or r < start:
        return 0
    if l <= start and end <= r:
        return tree[node]
    mid = start + end >> 1
    return query(l, r, start, mid, node << 1) + query(l, r, mid+1, end, node << 1 | 1)


arr = [*range(1, n+1)]
build(arr)
print(query(2, 5))      # 2+3+4+5
update(3, 2)            # 3 -> 5
print(query(2, 5))      # 2+5+4+5

2. c++

#include <bits/stdc++.h>

using namespace std;
using ll = long long;
int N;
ll tree[1<<19]; // N보다 큰 가장 작은 2의 제곱수의 2배라 가정 2^19

void build(vector<ll> &arr, int start=1, int end=N, int node=1){
    if (start == end){
        tree[node] = arr[start-1];
        return;
    }
    int mid = start + end >> 1;
    build(arr, start, mid, node<<1); 
    build(arr, mid+1, end, node<<1|1);
    tree[node] = tree[node<<1] + tree[node<<1|1];
}


void update(int idx, int plus, int start=1, int end=N, int node=1){
    if (start == end){
        tree[node] += plus;
        return;
    }
    int mid = start + end >> 1;
    if (idx <= mid){
        update(idx, plus, start, mid, node<<1);
    } else {
        update(idx, plus, mid+1, end, node<<1|1);
    }
    tree[node] = tree[node<<1] + tree[node<<1|1];
}

ll query(int l, int r, int start=1, int end=N, int node=1){
    if (end < l || r < start) return 0;
    if (l <= start && end <= r) return tree[node];
    int mid = start + end >> 1;
    return query(l, r, start, mid, node<<1) + query(l, r, mid+1, end, node<<1|1);
}

연습문제를 투척합니다.

14428번: 수열과 쿼리 16 (acmicpc.net)

18436번: 수열과 쿼리 37 (acmicpc.net)

5676번: 음주 코딩 (acmicpc.net)

피드백 및 오류 지적은 언제나 환영입니다.

질문은 댓글로 해주셔도 되고, 제 유튜브 채널을 방문해서 하셔도 되고, 오픈채팅방으로 하셔도 됩니다.

[유튜브 : 승욱은] 승욱은 - YouTube

[ 1:1 오픈채팅방] 카카오톡 오픈채팅 (kakao.com)

 

댓글