본문 바로가기
알고리즘 설명

[PS를 위한 자료구조 2강] 세그먼트 트리의 구현(재귀)

by 승욱은 2022. 3. 24.

 

 

PS를 위한 자료구조 1-2강

# 세그먼트 트리의 구현(재귀) # 에 대해 알아보겠습니다.

우선 세그먼트 트리의 구조에 대해 좀 더 이해해봅시다.

앞선 강의의 그림을 그대로 가져오겠습니다.

위 그림에서 볼 수 있는 특징을 정리하겠습니다.

1. 계층이 존재한다.

2. 한 층 내려갈 때마다 구간이 절반으로 쪼개진다.

3. 이 방식을 계층의 구간이 단 하나의 원소가 될 때 까지 반복된다.

이 특징을 정리해보면 아이디어가 하나 떠오를 겁니다. 바로 '재귀'입니다.

2번 특징은 현재 내가 쪼개고 있는 구간이 무엇인지 상관이 없죠.

그저 쪼개기만 하면 됩니다. 이것을 재귀로 간단하게 나타내면 다음과 같이 쓸 수 있습니다.

def 함수(시작 : int, 끝 : int) -> None:
    if 시작 == 끝 :
          # 시작과 끝이 같다면 구간의 원소는 1개일것입니다.
         return
    중간 = (시작 + 끝)/2 # 시작과 끝의 중간값
    함수(시작, 중간)
    함수(중간+1, 끝)
    # [시작 끝] 구간을 [시작 중간], [중간+1 끝]으로 나누어서 다시 재귀 호출합니다. 재귀 호출하면 그들끼리 알아서 쪼갤겁니다.
 
이제 본격적으로 코드에서 필요한 것이 무엇인지 설계해봅시다.

1. 각 구간의 정보을 저장할 무언가

2. 함수의 매개변수 및 리턴값, 설계

각 구간의 값을 저장할 무언가는 "객체"입니다.

주로 자료구조에서 이러한 역할을 수행하는 객체를 "노드"라고 부릅니다.

각 노드에게 필요한 변수들은 다음과 같습니다.

1. 담당하는 구간의 합

2. 담당하는 구간의 왼쪽, 오른쪽

3. 내 구간의 절반을 가진 자식 노드의 주소

나 자신의 구간의 절반을 가진 자식은 2개가 있으며,

왼쪽 구간을 담당하는 자식을 LEFT 노드, 오른쪽 구간을 담당하는 자식을 RIGHT 노드라고 칭하겠습니다.

Class 노드:
    int 구간의 합
    int 왼쪽, 오른쪽 # 구간의 왼쪽, 오른쪽 경계
    노드 LEFT, RIGHT # 노드는 자료형입니다.

노드의 설계도를 완성했습니다. 이제 2번 세그먼트 트리의 전처리 과정을 자세하게 살펴보죠.

처음에는 아무것도 없을 것입니다.

따라서 가장 상위 계층, 즉 전체 구간을 담당하는 노드를 생성하고,

특별히 이 노드를 "루트"라고 부르겠습니다. (트리의 뿌리)

그리고 구간을 절반으로 자르고 왼쪽 노드, 오른쪽 노드를 새롭게 생성해야합니다.

이를 위해 다음과 같이 함수를 정의하겠습니다.

 

def 세그먼트 트리 전처리(구간의 왼쪽 : int, 구간의 오른쪽 : int) -> 노드:
    # 해당하는 구간의 노드를 만들고, 해당 노드를 리턴합니다.
    NEW노드 = 노드()
    NEW노드.왼쪽 = 구간의 왼쪽
    NEW노드.오른쪽 = 구간의 오른쪽
    # 새로운 노드의 구간을 업데이트
    
    if 구간의 왼쪽 == 구간의 오른쪽:
        # 구간에 포함된 값이 1개라면, 더 쪼갤 수 없으므로 재귀 종료
        return NEW노드

    # 중간을 기준으로 구간 쪼개기를 진행.
    중간 = (구간의 왼쪽 + 구간의 오른쪽)/2
    NEW노드.LEFT = 세그먼트 트리 전처리(구간의 왼쪽, 중간)
    # 왼쪽 자식을 재귀로 생성
    NEW노드.RIGHT = 세그먼트 트리 전처리(중간+1, 구간의 오른쪽)
    # 오른쪽 자식을 재귀로 생성
    return NEW노드

해당 함수는 해당하는 구간의 노드를 만들고, 더 쪼개기가 가능하다면 구간 쪼개기를 진행합니다.​

진행한 이후 노드의 왼쪽 자식, 오른쪽 자식을 각각 LEFT, RIGHT에 등록하고, 자기 자신을 리턴합니다.

 새로운 노드를 만들고 모든 작업이 완료된 노드를 리턴하는 함수인 것입니다.

 

사실 전처리 함수의 설계는 이와 같은 방식만 있는 것은 아닙니다.

다만 제가 시도한 방식 중 가장 코드가 간단하고 직관적이기에 이렇게 전달드립니다.

만약 전체 세그먼트 트리의 구간이 1부터 8이라면 다음과 같이 전처리가 가능합니다.

 

루트 = 세그먼트 트리 전처리(1, 8)

 

끝입니다. 설계만 좀 어렵지 사용은 간단합니다.

우리 눈에 보이진 않아도 다음과 같은 그림이 만들어져있는 상태입니다.

하지만 아직 각각의 노드에는 아직 아무런 값도 등록되어 있지 않습니다. 이를 위해 다음 그림을 참고해봅시다.

여기서 알 수 있는 특징이 있습니다.

마지막 계층의 수가 정해진다면, 그 위쪽의 계층 들의 값은,

왼쪽 자식, 오른쪽 자식의 합으로서 자신의 합을 구할 수 있다는 겁니다.

더 일반적으로 이야기하면 특정 노드의 값은 자식 노드 둘의 '결합'으로 이루어집니다.

다음과 같은 방식을 생각할 수 있습니다.

 

1. 가장 마지막 계층의 노드들에 값을 등록합니다.

2. 해당 값을 부모에게 전달합니다.

하지만 눈치빠른 분들은 아셨을 겁니다.

마지막 계층을 직접적으로 접근하는 방법이 없을 뿐더러

부모에게 접근하는 방법은 없습니다.

 

따라서 다른 방법을 이용해야 합니다.

해당 문제를 해결하기 위해 재귀의 특성을 이용할겁니다.

전처리함수를 다시 불러옵니다.

 

def 세그먼트 트리 전처리(구간의 왼쪽 : int, 구간의 오른쪽 : int) -> 노드:
    # 해당하는 구간의 노드를 만들고, 해당 노드를 리턴합니다.
    NEW노드 = 노드()
    NEW노드.왼쪽 = 구간의 왼쪽
    NEW노드.오른쪽 = 구간의 오른쪽
    # 새로운 노드의 구간을 업데이트
    if 구간의 왼쪽 == 구간의 오른쪽:
        // 구간에 포함된 값이 1개라면, 해당값을 등록하고 더 쪼갤 수 없으므로 재귀 종료
        ################################################# 강조
        NEW노드.구간의 합 = 해당값
        ################################################# 강조
        return NEW노드
    # 중간을 기준으로 구간 쪼개기를 진행.
    중간 = (구간의 왼쪽 + 구간의 오른쪽)/2
    NEW노드.LEFT = 세그먼트 트리 전처리(구간의 왼쪽, 중간)
    # 왼쪽 자식을 재귀로 생성
    NEW노드.RIGHT = 세그먼트 트리 전처리(중간+1, 구간의 오른쪽)
    # 오른쪽 자식을 재귀로 생성
    ##################################################################### 강조
    NEW노드.구간의 합 = NEW노드.LEFT.구간의 합 + NEW노드.RIGHT.구간의합
    ##################################################################### 강조
    # 구간의 합은 두 자식의 구간의 합의 합
    return NEW노드
 

#으로 강조된 부분을 꼭 봐주세요.

마지막 계층에 도달했을 때, 그 값을 등록하고, 재귀를 마치고 올라가면서 부모들의 값을 완성하는 겁니다.

노드의 왼쪽, 오른쪽 자식을 재귀를 통해 완성하면,

왼쪽 자식의 합과 오른쪽 자식의 합을 현재 노드의 합으로 만들어주면,

현재 노드가 완벽하게 완성됩니다.

이제 위 그림이 모두 완성되었습니다.

이제 업데이트 함수와, 특정 구간의 합을 구하는 함수를 설계해보죠.

업데이트 함수

루트에서 시작해서 목표 인덱스가 왼쪽 자식 구간에 포함되면, 왼쪽 자식으로,

오른쪽 자식 구간에 포함되면 오른쪽 자식으로 재귀하여 들어갑니다.

해당 과정으로 마지막 계층까지 들어가면 값을 추가해주고, 종료합니다.

위 전처리 함수와 같이, 자식의 업데이트가 끝나면,

왼쪽 자식 노드의 합과 오른쪽 자식 노드의 합을 더해줘서 현재 노드를 업데이트 시킵니다.

def 업데이트(노드 : 노드, 번호 : int, 추가값 : int) -> None:
    # 해당 노드를 업데이트 하며, 리턴값은 없음
    if 노드.왼쪽 == 노드.오른쪽:
        # 마지막 계층 도달. 추가해주고 종료
        노드.구간의 합 += 추가값
        return
    중간 = (노드.왼쪽 + 노드.오른쪽)/2
    if 번호 <= 중간:
        업데이트(노드.LEFT, 번호, 추가값)
    else:
        업데이트(노드.RIGHT, 번호, 추가값)
    # 자식들의 업데이트가 모두 끝났으므로 다시 자신의 값을 구함
    노드.구간의 합 = 노드.LEFT.구간의 합 + 노드.RIGHT.구간의 합
 

해당 그림이 완성되었다고 볼 수 있습니다.

이제 구간의 합을 구해봅시다.

구간 합 쿼리

요점은 구간에 완전하게 포함되는 노드 가장 큰 것의 합들을 겹치지 않게 모아주면 됩니다.

여기서 키워드는 2개입니다.

1. 완전하게 포함되는

2. 구간에 포함되는 노드 중 가장 큰 것

[L,  R]까지의 구간이 있다고 해봅시다.

구간에 완전하게 포함되는 노드란 다음의 조건을 만족하는 노드일 것입니다.

L <= 노드.왼쪽 and 노드.오른쪽 <= R

그러면 완전하게 포함되지 않는 노드란 어떤 것일까요?

노드.오른쪽 < R OR L < 노드.왼쪽

다음의 조건을 만족한다면 이 노드는 아예 구간에 포함되지 않을 것이라는 것을 쉽게 알 수 있을 것입니다.

이 경우, 이 노드의 자식들도 구간에 포함되지 않을 것이 자명하기 때문에 더 이상 거들떠보지 않아도 됩니다.

 

마지막으로 완전히 포함도, 포함도 되지 않는 노드는 어떻게 해야할까요?

비록 자신은 완벽하게 포함되지 않지만, 자손들 중에는 완벽하게 포함되는 것이 반드시 존재할 것입니다.

따라서 자식들을 재귀적으로 호출하여 더 아래 계층으로 가야합니다.

노드 중 가장 큰 것은 어떻게 해결할 수 있을까요?

위쪽에서부터 아래쪽으로 재귀적으로 내려감으로서 해결이 가능합니다.

루트에서부터 시작하여 내려가면서, 구간에 완전히 포함되는 노드는 그 값을 답에 포함시키고,

더 이상 내려가지 않는 것입니다.

큰 것부터 봐왔기 때문에, 작은 값들은 절대 조회되지 않아 시간 낭비가 되지 않습니다.

이제 함수를 설계해봅시다.

def 쿼리(노드 : 노드, L : int, R : int) -> int :
    if 노드.오른쪽 < L OR R < 노드.왼쪽:
        # 완전히 포함되지 않는 노드는 답에 영향을 주지 않기 위해 0을 리턴
        return 0
    if L <= 노드.왼쪽 AND 노드.오른쪽 <= R:
        # 완전히 포함되는 노드는 자신의 값을 리턴
        return 노드.구간의 합
    # 위 두 조건에 걸리지 않았다면 애매하게 걸쳐 있는 노드, 자식들에게 맡기자.
    # 왼쪽 자식이 건져오는 값과 오른쪽 자식이 건져오는 값을 합함
    # 만약 포함되는게 없다면 0으로 리턴하여 아무런 영향이 없을것
    return 쿼리(노드.LEFT, L, R) + 쿼리(노드.RIGHT, L, R)

 위의 함수로 구간 합을 구하는 모든 것이 설명되었으리라 믿으며 시간복잡도를 분석해보겠습니다.

이 부분은 조금 난이도가 있기 때문에 생략해도 괜찮습니다.

1. 해당 구간 크기의 절반 이상 크기인 노드는 최대 2개 존재하며, 최소한 1개는 존재한다.

2. 1에 따라, 절반 이상 크기인 노드를 포함시키고 남은 구간의 크기는 원래 구간의 절반 이하가 된다.

3. 나눠진 구간 역시 1에 따라 절반 이상 크기인 노드는 최대 2개 존재한다.

4. 3에 따라, 각 계층에서 최대 2개의 노드만이 선택될 수 있다.

계층의 수는 대략 LogN이므로 선택되는 노드는 최대 2LogN개이며,

재귀 가지는 노드당 2번까지 뻗을 수 있으므로, 방문하는 노드는 최대 4LogN개입니다.

즉 해당 연산의 시간복잡도는 상수가 생략되어 O(LogN)임을 알 수 있습니다.

이렇게 세그먼트 트리의 재귀 구현에 대해 알아봤습니다.

하지만 위 구현은 최적화가 덜 되어있습니다.

글이 좀 길어져서 재귀 세그먼트 트리의 2가지 최적화에 대해 다음 강의에서 설명하겠습니다.

위 내용에 기반한 구현을 구현을 C++과 파이썬 모두 첨부합니다.

1. 파이썬 코드

arr = [*range(1, 257)]  # 1부터 256


class Node:
    def __init__(self, start, end) -> None:
        self.sum = 0                    # 구간의 합
        self.start = start              # 구간의 시작
        self.end = end                  # 구간의 끝
        self.left = self.right = None   # 왼쪽, 오른쪽 자식


class Segtree:
    def __init__(self) -> None:
        # 세그먼트 트리 [1, 256] 범위 생성자
        self.root = self.init(1, 256)
        # 트리의 루트 생성

    def init(self, start: int, end: int) -> Node:
        # 노드 생성
        node = Node(start, end)
        if start == end:
            # 마지막 계층 도달. 노드의 값 등록
            # 현재 세그먼트 트리는 1-based index이고
            # arr는 0-based index이므로 start에 1을 빼줌
            node.sum = arr[start-1]
            return node
        mid = (start + end)//2
        # 왼쪽, 오른쪽 자식들을 생성하고
        node.left = self.init(start, mid)
        node.right = self.init(mid+1, end)
        # 완성 이후 구간의 합을 줍줍
        node.sum = node.left.sum + node.right.sum
        return node

    def update(self, node: Node, idx: int, plus: int) -> None:
        # 노드를 업데이트
        # idx : 업데이트 하고자 하는 값
        # plus : 더하고자 하는 값
        if node.start == node.end:
            # 마지막 계층 도달. 노드의 값 추가
            node.sum += plus
            return

        mid = (node.start + node.end)//2
        # idx가 mid보다 작거나 같으면 왼쪽으로 아니면 오른쪽으로
        if idx <= mid:
            self.update(node.left, idx, plus)
        else:
            self.update(node.right, idx, plus)
        # 자식들 업데이트 완료. 구간합 줍줍
        node.sum = node.left.sum + node.right.sum

    def query(self, node: Node, l: int, r: int) -> int:
        # 구간합 구하기
        # l : 구간의 왼쪽 값, r : 구간의 오른쪽 값
        if node.end < l or r < node.start:
            return 0
        if l <= node.start and node.end <= r:
            return node.sum
        mid = (node.start + node.end)//2
        return self.query(node.left, l, r) + self.query(node.right, l, r)


tree = Segtree()

print(tree.query(tree.root, 2, 5))          # 2+3+4+5
tree.update(tree.root, 3, 2)                # 3 -> 5
print(tree.query(tree.root, 2, 5))          # 2+5+4+5

 

2. C++

#include <bits/stdc++.h>

using ll = long long;
using namespace std;

struct Node
{
    int start, end;
    Node *left, *right;
    ll sum;
    Node(int start, int end) : start(start), end(end) {
        left = right = nullptr;
        sum = 0;
    }
    ~Node(){
        if (left) delete left;
        if (right) delete right;
    }
};

struct Segtree
{
    Node *tree;
    vector<ll> arr;
    ~Segtree(){
        if (tree) delete tree;
    }

    void bulid(int start, int end, vector<ll> &tmp){
        arr = tmp;
        tree = init(start, end);
    }

    Node* init(int start, int end){
        Node *node = new Node(start, end);
        if (start == end){
            node->sum = arr[start-1];
            return node;
        }
        int mid = (start + end)/2;
        node->left = init(start, mid);
        node->right = init(mid+1, end);
        node->sum = node->left->sum + node->right->sum;
        return node;
    }

    void update(Node *node, int idx, int plus){
        if (node->start == node->end){
            node->sum += plus;
            return;
        }
        int mid = (node->start + node->end)/2;
        if (idx <= mid) update(node->left, idx, plus);
        else update(node->right, idx, plus);
        node->sum = node->left->sum + node->right->sum;
    }

    ll query(Node *node, int l, int r){
        if (node->end < l || r < node->start) return 0;
        if (l <= node->start && node->end <= r) return node->sum;
        int mid = (node->start + node->end)/2;
        return query(node->left, l, r) + query(node->right, l, r);
    }

};

   이제 여러분들은 세그먼트 트리의 기본에 대해 알았습니다.

연습문제를 투척합니다.

2042번: 구간 합 구하기 (acmicpc.net)

11505번: 구간 곱 구하기 (acmicpc.net)

2357번: 최솟값과 최댓값 (acmicpc.net)

 

[유튜브 : 승욱은] 승욱은 - YouTube

[ 1:1 오픈채팅방] 카카오톡 오픈채팅 (kakao.com)

 

댓글