线段树模板


用 Python 重写线段树模板。记忆线段树的要点是脑子里要有线段树的结构图;它的特点是自顶向下递归构建。

class segment_tree:
    """Sum segment tree with lazy range-add over 1-indexed positions.

    Nodes live in flat lists: node ``root`` has children ``root * 2`` and
    ``root * 2 + 1``, and the top node is index 1.  Every method receives
    the range ``[left, right]`` covered by ``root`` explicitly, so top-level
    calls pass ``(1, n, 1)`` where ``n == len(nums)``.

    Usage::

        tree = segment_tree(nums)
        tree.build(1, n, 1)
        tree.change(A, B, v, 1, n, 1)   # add v to every element of [A, B]
        tree.query(A, B, 1, n, 1)       # sum of the elements of [A, B]
    """

    def __init__(self, nums):
        """Allocate node storage for ``nums``; ``build`` fills it in."""
        n = len(nums)
        self.nums = nums
        # tree[node]: sum of the node's range, including its own pending tag.
        self.tree = [0] * (4 * n)
        # tag[node]: per-element add that has not been pushed to the children.
        self.tag = [0] * (4 * n)

    def update(self, root):
        """Recompute the sum of ``root`` from its two children."""
        self.tree[root] = self.tree[root * 2] + self.tree[root * 2 + 1]

    # lazy change
    def down(self, left, right, root):
        """Push the pending add of ``root`` (covering [left, right]) down.

        Both children receive the tag and have their sums adjusted by
        ``tag * child_length``; the tag on ``root`` is then cleared.
        """
        pending = self.tag[root]
        if pending == 0:
            return
        mid = (left + right) // 2
        lchild = root * 2
        rchild = lchild + 1
        self.tag[lchild] += pending
        self.tag[rchild] += pending
        self.tree[lchild] += (mid - left + 1) * pending
        self.tree[rchild] += (right - mid) * pending
        self.tag[root] = 0

    # [A, B] add v
    # change(A, B, v, 1, n, 1)
    def change(self, A, B, v, left, right, root):
        """Add ``v`` to every position in ``[A, B]``.

        ``root`` is the current node and ``[left, right]`` the range it
        covers.  Parts of ``[A, B]`` that fall outside ``[left, right]``
        (including an empty request with ``A > B``) are ignored.
        """
        # Nothing to do when the node range is empty or disjoint from [A, B].
        # Without this guard the recursion would walk past the leaves.
        if left > right or B < left or right < A:
            return
        if A <= left and right <= B:
            # Fully covered: record the add lazily on this node.
            self.tag[root] += v
            self.tree[root] += (right - left + 1) * v
            return
        self.down(left, right, root)
        mid = (left + right) // 2
        if mid >= A:
            self.change(A, B, v, left, mid, root * 2)
        if mid + 1 <= B:
            self.change(A, B, v, mid + 1, right, root * 2 + 1)
        self.update(root)

    # build(1, n, 1)
    def build(self, left, right, root):
        """Fill the subtree at ``root`` (covering [left, right]) from ``nums``.

        Position ``i`` (1-indexed) takes its value from ``self.nums[i - 1]``.
        An empty range (``left > right``, e.g. empty ``nums``) is a no-op.
        """
        if left > right:
            return
        if left == right:
            self.tree[root] = self.nums[left - 1]
            return
        mid = (left + right) // 2
        self.build(left, mid, root * 2)
        self.build(mid + 1, right, root * 2 + 1)
        self.update(root)

    # query(A, B, 1, n, 1)
    def query(self, A, B, left, right, root):
        """Return the sum of the positions in ``[A, B]``.

        ``root`` is the current node and ``[left, right]`` the range it
        covers.  Positions outside ``[left, right]`` contribute nothing, so
        a disjoint or empty request (``A > B``) returns 0.
        """
        # Empty or disjoint node range contributes nothing; this also keeps
        # the recursion from descending below the leaves.
        if left > right or B < left or right < A:
            return 0
        if A <= left and right <= B:
            return self.tree[root]
        self.down(left, right, root)
        mid = (left + right) // 2
        res = 0
        if mid >= A:
            res += self.query(A, B, left, mid, root * 2)
        if mid + 1 <= B:
            res += self.query(A, B, mid + 1, right, root * 2 + 1)
        return res

import sys
# Input layout as read here: line 1 is "n m", line 2 holds the initial
# values, and each of the next m lines is one operation.
n, m = (int(tok) for tok in sys.stdin.readline().split())
nums = [int(tok) for tok in sys.stdin.readline().split()]

# Build the tree over positions 1..n.
tree = segment_tree(nums)
tree.build(1, n, 1)

for _ in range(m):
    a = [int(tok) for tok in sys.stdin.readline().split()]
    # Fields 1 and 2 are the range bounds for both kinds of operation.
    x, y = a[1], a[2]
    if len(a) != 4:
        # query: print the sum over [x, y]
        res = tree.query(x, y, 1, n, 1)
        print(res)
        continue
    # change: a four-field line adds k to every element of [x, y]
    k = a[3]
    tree.change(x, y, k, 1, n, 1)