一、蓝桥杯简介与备考策略

比赛概述

蓝桥杯全国软件和信息技术专业人才大赛是由工业和信息化部人才交流中心举办的年度赛事。Python 高中组比赛形式为闭卷机试,比赛时长 4 小时。

题型分布

| 题型 | 数量 | 分值 | 特点 |
| --- | --- | --- | --- |
| 结果填空题 | 5-6 题 | 45-50 分 | 只需提交最终答案,无需提交代码 |
| 编程大题 | 5-6 题 | 50-55 分 | 需要提交完整代码,按测试点给分 |

备考建议

  1. 夯实基础:熟练掌握 Python 基本语法和常用库函数
  2. 刷题积累:完成历年真题,总结常见题型和解题模板
  3. 注重细节:注意边界条件、数据范围、时间复杂度
  4. 调试技巧:学会使用 print 调试和边界测试

二、Python 基础语法

输入输出

输入函数

1
2
3
4
5
6
7
8
9
10
# Read one whole line as a string.
s = input()
# Read a single integer / a single float, each on its own line.
n = int(input())
m = float(input())

# Two integers on one line, then a full line of integers as a list.
a, b = (int(tok) for tok in input().split())
nums = [int(tok) for tok in input().split()]

# Multi-line input: a row count n followed by n rows of integers.
n = int(input())
matrix = [[int(tok) for tok in input().split()] for _ in range(n)]

输出函数

1
2
3
4
5
6
# Sample values so this snippet runs on its own (c and ans were never defined).
a, b, c = 1, 2, 3
ans = 42

print("Hello World")              # print a string
print(a, b, c)                    # several values, separated by spaces
print(a, b, c, sep=',')           # separate with commas instead
print("答案是:", ans, sep='')     # no separator between the pieces
print(f"结果是 {ans}")            # f-string formatting
print("{:.2f}".format(3.14159))   # keep two decimal places -> 3.14

数据类型

数值类型

1
2
3
4
5
6
7
8
9
10
a = 10              # int
b = 3.14            # float
c = 10 ** 100       # ints are arbitrary precision in Python: no overflow

# Everyday arithmetic; the trailing comment shows what each line prints.
print(7 // 2)           # floor division -> 3
print(7 % 2)            # remainder -> 1
print(2 ** 10)          # exponentiation -> 1024
print(abs(-5))          # absolute value -> 5
print(divmod(17, 5))    # quotient and remainder together -> (3, 2)

字符串操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
s = "Hello World"

# Indexing, slicing, length
print(s[0])                           # first character -> 'H'
print(s[0:5])                         # characters 0..4 -> 'Hello'
print(s[::-1])                        # step -1 reverses the string
print(len(s))                         # number of characters

# Case conversion
print(s.lower())                      # all lower case
print(s.upper())                      # all upper case

# Split / join / replace
print(s.split())                      # split on whitespace -> ['Hello', 'World']
print('-'.join(['a', 'b', 'c']))      # glue with '-' -> 'a-b-c'
print(s.replace('World', 'Python'))   # substitute a substring

# Searching
print(s.find('o'))                    # index of the first 'o' -> 4
print(s.count('l'))                   # how many times 'l' occurs

# Character <-> code point conversion
print(ord('A'))                       # character to code -> 65
print(chr(65))                        # code to character -> 'A'

条件语句

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
score = 85

# if / elif / else: only the first branch whose condition holds runs.
if score >= 90:
    print("优秀")
elif score >= 80:
    print("良好")
elif score >= 60:
    print("及格")
else:
    print("不及格")

# Conditional (ternary) expression
result = "及格" if score >= 60 else "不及格"

# Chained comparison: same as (60 <= score) and (score < 80)
if 60 <= score < 80:
    print("中等")

循环语句

for 循环

1
2
3
4
5
6
7
8
9
10
11
# range(stop): 0, 1, 2, 3, 4
for i in range(5):
    print(i)

# range(start, stop, step): 1, 3, 5, 7, 9
for i in range(1, 10, 2):
    print(i)

# A negative step counts down: 10, 9, 8, ..., 1
for i in range(10, 0, -1):
    print(i)

# enumerate yields (index, value) pairs
for idx, val in enumerate(['a', 'b', 'c']):
    print(idx, val)

while 循环

1
2
3
4
5
6
7
8
9
10
11
12
n = 10
while n > 0:
    print(n)
    n -= 1

# Loop control
for i in range(10):
    if i == 3:
        continue  # skip the rest of this iteration
    if i == 7:
        break     # leave the loop entirely
    print(i)

函数定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def add(a, b):
    """Return a + b."""
    return a + b


def greet(name, msg="你好"):
    """Print a greeting; msg has a default value."""
    print(f"{msg}, {name}!")


def get_stats(nums):
    """Return (min, max, sum) of nums as a tuple (multiple return values)."""
    return min(nums), max(nums), sum(nums)


# Recursive function
def factorial(n):
    """Return n! recursively; 1 for n <= 1."""
    if n <= 1:
        return 1
    return n * factorial(n - 1)


# Lambda expressions. The second one gets its own name instead of rebinding
# `add`, which would silently replace the def above.
square = lambda x: x ** 2
add_lambda = lambda a, b: a + b

三、核心数据结构

列表

基本操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
arr = [1, 2, 3, 4, 5]

# Mutating operations, applied in order
arr.append(6)             # add at the end
arr.insert(0, 0)          # insert at a given index
arr.pop()                 # remove the last element
arr.pop(0)                # remove the element at an index
arr.remove(3)             # remove the first occurrence of a value
arr.extend([7, 8])        # append every item of another list
arr.reverse()             # reverse in place
arr.sort()                # sort in place, ascending
arr.sort(reverse=True)    # sort in place, descending

# Read-only queries
print(arr[0])             # first element
print(arr[-1])            # last element
print(arr[1:4])           # slice of indices 1..3
print(len(arr))           # length
print(sum(arr))           # total
print(max(arr))           # largest value
print(min(arr))           # smallest value

列表推导式

1
2
3
4
squares = [v**2 for v in range(10)]
evens = [v for v in range(20) if not v % 2]
matrix = [[0] * 3 for _row in range(3)]          # 3x3 grid, independent rows
flattened = [cell for row in matrix for cell in row]  # 2-D -> 1-D

字典

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
d = {'a': 1, 'b': 2, 'c': 3}

d['d'] = 4              # add a key/value pair
d['a'] = 10             # overwrite an existing value
print(d.get('e', 0))    # safe lookup: returns the default for a missing key
d.pop('a')              # remove a key/value pair

print(d.keys())         # all keys
print(d.values())       # all values
print(d.items())        # all (key, value) pairs

for key, val in d.items():
    print(key, val)

# Dict comprehension
squares = {x: x**2 for x in range(5)}

集合

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
s = {1, 2, 3, 4, 5}

s.add(6)          # insert an element
s.remove(1)       # delete an element; KeyError if it is missing
s.discard(10)     # delete an element; silently ignores a missing one

# Set algebra
a = {1, 2, 3}
b = {2, 3, 4}
print(a | b)      # union
print(a & b)      # intersection
print(a - b)      # difference
print(a ^ b)      # symmetric difference

# De-duplicate a list (element order is not preserved)
nums = [1, 1, 2, 2, 3, 3]
unique = [*set(nums)]

元组

1
2
3
4
5
6
t = (1, 2, 3)
print(t[0])    # index access, just like a list
a, b, c = t    # unpack into three names

# Tuples are hashable, so they work as dict keys (e.g. grid coordinates)
d = {(0, 0): '起点', (1, 1): '终点'}

四、常用内置模块

math 模块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import math

# Constants
print(math.pi)              # pi
print(math.e)               # Euler's number

# Roots, rounding, powers, logarithms
print(math.sqrt(16))        # square root -> 4.0 (always a float)
print(math.ceil(3.2))       # round up -> 4
print(math.floor(3.8))      # round down -> 3
print(math.pow(2, 10))      # power, as a float
print(math.log(100, 10))    # logarithm of 100 in base 10

# Number theory
print(math.gcd(12, 8))      # greatest common divisor
print(math.lcm(4, 6))       # least common multiple (Python 3.9+)

# Trigonometry (arguments in radians)
print(math.sin(math.pi / 2))  # sine
print(math.cos(0))            # cosine
print(math.tan(math.pi / 4))  # tangent

itertools 模块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from itertools import permutations, combinations, product

# Permutations of length 2 (order matters)
for p in permutations([1, 2, 3], 2):
    print(p)  # (1,2), (1,3), (2,1), (2,3), (3,1), (3,2)

# Combinations of length 2 (order does not matter)
for c in combinations([1, 2, 3, 4], 2):
    print(c)  # (1,2), (1,3), (1,4), (2,3), (2,4), (3,4)

# Cartesian product
for p in product([1, 2], ['a', 'b']):
    print(p)  # (1,'a'), (1,'b'), (2,'a'), (2,'b')

# Infinite iterators: always pair them with an explicit stop condition.
from itertools import count, cycle, repeat
for i in count(1, 2):  # 1, 3, 5, 7, ...
    if i > 10:
        break
    print(i)  # prints 1, 3, 5, 7, 9

collections 模块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
from collections import Counter, defaultdict, deque

# Counter: frequency table
cnt = Counter(['a', 'b', 'a', 'c', 'a', 'b'])
print(cnt['a'])              # -> 3
print(cnt.most_common(2))    # the two most frequent items with their counts

# defaultdict: missing keys are created on first access
d = defaultdict(list)
d['fruits'].append('apple')  # no need to create the list first

# deque: O(1) push/pop at both ends
q = deque([1, 2, 3])
q.append(4)                  # push on the right
q.appendleft(0)              # push on the left
q.pop()                      # pop from the right
q.popleft()                  # pop from the left

datetime 模块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from datetime import datetime, timedelta

# The current local date and time, then formatted as text
now = datetime.now()
print(now.year, now.month, now.day)
print(now.strftime("%Y-%m-%d %H:%M:%S"))

# Subtracting two datetimes gives a timedelta
d1 = datetime(2024, 1, 1)
d2 = datetime(2024, 12, 31)
print((d2 - d1).days)  # whole days between the two dates

# Shifting a datetime by a duration
tomorrow = now + timedelta(days=1)
next_week = now + timedelta(weeks=1)

五、基础算法

排序算法

Python 内置排序

1
2
3
4
5
6
7
8
9
10
11
12
arr = [3, 1, 4, 1, 5, 9, 2, 6]

arr.sort()                 # in place; returns None
sorted_arr = sorted(arr)   # builds and returns a new list

# Custom key: sort by score, highest first
students = [('Alice', 85), ('Bob', 92), ('Carol', 78)]
students.sort(key=lambda rec: rec[1], reverse=True)

# Several criteria at once: first field ascending, second field descending
data = [(1, 3), (2, 2), (1, 2), (2, 3)]
data.sort(key=lambda pair: (pair[0], -pair[1]))

手写快速排序

1
2
3
4
5
6
7
8
def quick_sort(arr):
    """Return a new ascending list holding the elements of arr.

    Three-way partition around the middle element: values equal to the pivot
    are grouped into `middle`, so duplicates never recurse again. The input
    sequence is not modified.
    """
    if len(arr) <= 1:
        return list(arr)  # always hand back a fresh list, never the input
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)

查找算法

二分查找

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def binary_search(arr, target):
    """Return an index of target in the ascending list arr, or -1 if absent.

    Closed search interval [left, right].
    """
    left, right = 0, len(arr) - 1
    while left <= right:
        mid = (left + right) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1


# First position whose value is >= target
def lower_bound(arr, target):
    """Return the first index i with arr[i] >= target, or len(arr) if none.

    Half-open search interval [left, right); same result as bisect_left.
    """
    left, right = 0, len(arr)
    while left < right:
        mid = (left + right) // 2
        if arr[mid] < target:
            left = mid + 1
        else:
            right = mid
    return left


# The standard-library bisect module does the same job
import bisect
arr = [1, 2, 2, 2, 3, 4]
print(bisect.bisect_left(arr, 2))   # first position with value >= 2 -> 1
print(bisect.bisect_right(arr, 2))  # first position with value > 2 -> 4

递归算法

阶乘与斐波那契

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def factorial(n):
    """Return n! recursively; 1 for n <= 1."""
    if n <= 1:
        return 1
    return n * factorial(n - 1)


def fibonacci(n):
    """Return the n-th Fibonacci number (F(0)=0, F(1)=1) by plain recursion.

    Exponential time: the same subproblems are recomputed many times.
    """
    if n <= 1:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)


# Memoisation
from functools import lru_cache


@lru_cache(maxsize=None)
def fib(n):
    """Memoised Fibonacci: the cache makes each n computed only once."""
    if n <= 1:
        return n
    return fib(n - 1) + fib(n - 2)

搜索算法

深度优先搜索 (DFS)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def dfs(graph, start, visited=None):
    """Recursive depth-first traversal from start, printing nodes as visited.

    graph maps a node to an iterable of its neighbours. Returns the set of
    visited nodes.
    """
    if visited is None:
        visited = set()  # a fresh set per top-level call (no mutable default)
    visited.add(start)
    print(start, end=' ')
    for neighbor in graph[start]:
        if neighbor not in visited:
            dfs(graph, neighbor, visited)
    return visited


# Maze solving with DFS
def solve_maze(maze, x, y, end_x, end_y, visited=None):
    """Return True if (end_x, end_y) is reachable from (x, y).

    maze is a grid where 1 marks a wall; moves are 4-directional. visited
    collects explored cells and may be omitted.
    """
    if visited is None:
        visited = set()
    # Bounds and walls are checked BEFORE the goal test, so a goal that is a
    # wall (or lies off the grid) is never reported as reached.
    if x < 0 or x >= len(maze) or y < 0 or y >= len(maze[0]):
        return False
    if maze[x][y] == 1 or (x, y) in visited:
        return False
    if x == end_x and y == end_y:
        return True

    visited.add((x, y))
    directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    for dx, dy in directions:
        if solve_maze(maze, x + dx, y + dy, end_x, end_y, visited):
            return True
    return False

广度优先搜索 (BFS)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from collections import deque


def bfs(graph, start):
    """Breadth-first traversal from start; prints and returns the visit order.

    graph maps a node to an iterable of its neighbours.
    """
    visited = {start}
    queue = deque([start])
    order = []

    while queue:
        node = queue.popleft()
        print(node, end=' ')
        order.append(node)
        for neighbor in graph[node]:
            if neighbor not in visited:
                visited.add(neighbor)  # mark when enqueued so nodes queue once
                queue.append(neighbor)
    return order


# Shortest path with BFS
def shortest_path(maze, start, end):
    """Return the fewest 4-directional steps from start to end, or -1.

    maze is a grid with 0 = open and anything else = blocked. start and end
    are (row, col) pairs; lists are accepted too and are converted to tuples
    so set membership and the goal comparison work.
    """
    start, end = tuple(start), tuple(end)
    queue = deque([(start[0], start[1], 0)])
    visited = {start}

    while queue:
        x, y, dist = queue.popleft()
        if (x, y) == end:
            return dist

        for dx, dy in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
            nx, ny = x + dx, y + dy
            if 0 <= nx < len(maze) and 0 <= ny < len(maze[0]):
                if maze[nx][ny] == 0 and (nx, ny) not in visited:
                    visited.add((nx, ny))
                    queue.append((nx, ny, dist + 1))
    return -1

六、动态规划

基本概念

动态规划的核心思想是将复杂问题分解为子问题,通过存储子问题的解避免重复计算。

适用条件

  1. 最优子结构:问题的最优解包含子问题的最优解
  2. 重叠子问题:子问题会被多次计算

经典问题

斐波那契数列

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def fib_dp(n):
    """Return F(n) with a bottom-up table: dp[i] = dp[i-1] + dp[i-2]."""
    if n <= 1:
        return n
    dp = [0] * (n + 1)
    dp[1] = 1
    for i in range(2, n + 1):
        dp[i] = dp[i - 1] + dp[i - 2]
    return dp[n]


# Space-optimised version
def fib_optimized(n):
    """Return F(n) in O(1) extra space by keeping only the last two values."""
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b

爬楼梯问题

1
2
3
4
5
6
7
8
def climb_stairs(n):
    """Number of ways to climb n steps taking 1 or 2 steps at a time.

    dp[i] = dp[i-1] + dp[i-2]. Returns n itself for n <= 2.
    """
    if n <= 2:
        return n
    dp = [0] * (n + 1)
    dp[1], dp[2] = 1, 2
    for i in range(3, n + 1):
        dp[i] = dp[i - 1] + dp[i - 2]
    return dp[n]

0-1 背包问题

1
2
3
4
5
6
7
8
9
def knapsack(weights, values, capacity):
    """Best total value of a 0-1 knapsack, with a 1-D DP table.

    dp[w] is the best value using total weight at most w. Weights are scanned
    from high to low so each item is used at most once.
    """
    n = len(weights)
    dp = [0] * (capacity + 1)

    for i in range(n):
        for w in range(capacity, weights[i] - 1, -1):
            dp[w] = max(dp[w], dp[w - weights[i]] + values[i])

    return dp[capacity]

最长公共子序列 (LCS)

1
2
3
4
5
6
7
8
9
10
11
12
def lcs(s1, s2):
    """Length of the longest common subsequence of s1 and s2.

    dp[i][j] is the LCS length of s1[:i] and s2[:j].
    """
    m, n = len(s1), len(s2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if s1[i - 1] == s2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

    return dp[m][n]

最长递增子序列 (LIS)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def lis(nums):
    """Length of the longest strictly increasing subsequence, O(n^2) DP.

    dp[i] is the LIS length ending at index i. Returns 0 for empty input
    (max() of an empty list would raise).
    """
    n = len(nums)
    if n == 0:
        return 0
    dp = [1] * n

    for i in range(1, n):
        for j in range(i):
            if nums[j] < nums[i]:
                dp[i] = max(dp[i], dp[j] + 1)

    return max(dp)


# O(n log n) version
import bisect


def lis_optimized(nums):
    """Same result as lis(), in O(n log n).

    tails[k] is the smallest possible tail of an increasing subsequence of
    length k + 1; bisect_left keeps the subsequence strictly increasing.
    """
    tails = []
    for num in nums:
        pos = bisect.bisect_left(tails, num)
        if pos == len(tails):
            tails.append(num)
        else:
            tails[pos] = num
    return len(tails)

七、数学相关算法

质数判断与筛法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import math


def is_prime(n):
    """Trial-division primality test: try odd divisors up to sqrt(n)."""
    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # math.isqrt is exact; int(n ** 0.5) can be off by one for huge n.
    for i in range(3, math.isqrt(n) + 1, 2):
        if n % i == 0:
            return False
    return True


def sieve_of_eratosthenes(n):
    """Return all primes <= n in ascending order (empty list when n < 2)."""
    if n < 2:
        return []  # flags[1] below would be out of range for n == 0
    # Named `flags` so it does not shadow the is_prime() function above.
    flags = [True] * (n + 1)
    flags[0] = flags[1] = False

    for i in range(2, math.isqrt(n) + 1):
        if flags[i]:
            # Multiples below i*i were already crossed out by smaller primes.
            for j in range(i * i, n + 1, i):
                flags[j] = False

    return [i for i, prime in enumerate(flags) if prime]

最大公约数与最小公倍数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def gcd(a, b):
    """Greatest common divisor by the Euclidean algorithm (never negative)."""
    while b:
        a, b = b, a % b
    # Python's % takes the sign of the divisor, so a can end up negative.
    return abs(a)


def lcm(a, b):
    """Least common multiple (never negative); 0 if either argument is 0."""
    if a == 0 or b == 0:
        return 0  # avoids dividing by gcd(0, 0) == 0
    return abs(a * b) // gcd(a, b)


# Extended Euclidean algorithm
def extended_gcd(a, b):
    """Return (g, x, y) with a*x + b*y == g, where g is gcd(a, b)."""
    if b == 0:
        return a, 1, 0
    g, x, y = extended_gcd(b, a % b)
    return g, y, x - (a // b) * y

快速幂

1
2
3
4
5
6
7
8
9
10
11
12
def fast_pow(base, exp, mod=None):
    """Compute base ** exp, optionally modulo mod, by binary exponentiation.

    Uses O(log exp) multiplications. A falsy mod (None or 0) means no
    reduction. Raises ValueError for a negative exponent (the loop would
    otherwise silently return 1).
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")
    # Start from 1 % mod so that mod == 1 gives 0 even when exp == 0.
    result = 1 % mod if mod else 1
    if mod:
        base %= mod
    while exp > 0:
        if exp & 1:  # lowest bit of exp set: include the current power
            result = result * base
            if mod:
                result %= mod
        base = base * base
        if mod:
            base %= mod
        exp >>= 1
    return result

组合数计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def factorial(n, mod=None):
    """Return n!, reduced modulo mod after every step when mod is given."""
    result = 1
    for i in range(2, n + 1):
        result *= i
        if mod:
            result %= mod
    return result


def combination(n, k, mod=None):
    """Binomial coefficient C(n, k), optionally reduced modulo mod.

    Returns 0 when k < 0 or k > n. The value is computed exactly and reduced
    only at the end. Floor-dividing factorials that were already reduced mod
    p does not give the right answer, and it divides by zero once n >= mod.
    Reducing at the end is correct for any modulus.
    """
    if k > n or k < 0:
        return 0
    k = min(k, n - k)  # C(n, k) == C(n, n - k); fewer iterations
    result = 1
    for i in range(k):
        # After this line result == C(n, i + 1); the division is exact.
        result = result * (n - i) // (i + 1)
    return result % mod if mod else result


# Pascal's triangle precomputation
def init_combinations(max_n, mod=None):
    """Table C[i][j] = C(i, j) for 0 <= j <= i <= max_n (0 elsewhere).

    Built row by row with C(i, j) = C(i-1, j-1) + C(i-1, j).
    """
    C = [[0] * (max_n + 1) for _ in range(max_n + 1)]
    for i in range(max_n + 1):
        C[i][0] = 1
        for j in range(1, i + 1):
            C[i][j] = C[i - 1][j - 1] + C[i - 1][j]
            if mod:
                C[i][j] %= mod
    return C

八、图论基础

图的表示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Adjacency list: each vertex maps to the list of its neighbours
graph = {
    'A': ['B', 'C'],
    'B': ['A', 'D', 'E'],
    'C': ['A', 'F'],
    'D': ['B'],
    'E': ['B', 'F'],
    'F': ['C', 'E'],
}

# Adjacency matrix: matrix[u][v] == 1 means an edge between u and v
n = 5
matrix = [[0 for _col in range(n)] for _row in range(n)]
matrix[0][1] = 1
matrix[1][0] = 1  # undirected edge: store it in both directions

最短路径算法

Dijkstra 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import heapq


def dijkstra(graph, start):
    """Single-source shortest paths for non-negative edge weights.

    graph maps a node to an iterable of (neighbor, weight) pairs. Returns a
    dict node -> distance, with float('inf') for unreachable nodes. Nodes
    that appear only as edge targets (with no key of their own) are handled
    and included in the result.
    """
    inf = float('inf')
    distances = {node: inf for node in graph}
    distances[start] = 0
    pq = [(0, start)]

    while pq:
        dist, node = heapq.heappop(pq)
        if dist > distances[node]:
            continue  # stale heap entry: a shorter path was already found
        for neighbor, weight in graph.get(node, ()):
            new_dist = dist + weight
            if new_dist < distances.get(neighbor, inf):
                distances[neighbor] = new_dist
                heapq.heappush(pq, (new_dist, neighbor))

    return distances

Floyd-Warshall 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def floyd_warshall(graph):
    """All-pairs shortest paths from an n x n adjacency matrix.

    graph[i][j] is the weight of edge i -> j, and 0 means "no edge", so
    zero-weight edges cannot be expressed. Self-loops are ignored: the
    distance from a node to itself starts at 0. Returns an n x n matrix with
    float('inf') for unreachable pairs. O(n^3).
    """
    n = len(graph)
    inf = float('inf')
    dist = [[inf] * n for _ in range(n)]

    for i in range(n):
        for j in range(n):
            if i != j and graph[i][j] != 0:
                dist[i][j] = graph[i][j]
        dist[i][i] = 0

    for k in range(n):
        dist_k = dist[k]
        for i in range(n):
            d_ik = dist[i][k]
            if d_ik == inf:
                continue  # no path i -> k, so nothing can improve through k
            row = dist[i]
            for j in range(n):
                if d_ik + dist_k[j] < row[j]:
                    row[j] = d_ik + dist_k[j]

    return dist

并查集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class UnionFind:
    """Disjoint-set union over 0..n-1 with path compression and union by rank."""

    def __init__(self, n):
        self.parent = list(range(n))  # every element starts as its own root
        self.rank = [0] * n           # upper bound on each root's tree height

    def find(self, x):
        """Return the root of x's set, re-pointing x directly at that root."""
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        """Merge the sets holding x and y; return False if already merged."""
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
        # Hang the lower-rank tree under the higher-rank one to stay shallow.
        if self.rank[px] < self.rank[py]:
            px, py = py, px
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        return True

    def connected(self, x, y):
        """Return True if x and y belong to the same set."""
        return self.find(x) == self.find(y)

九、字符串处理

字符串匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def find_all(text, pattern):
    """Start indices of every occurrence of pattern in text, overlaps included."""
    positions = []
    start = 0
    while True:
        pos = text.find(pattern, start)
        if pos == -1:
            break
        positions.append(pos)
        start = pos + 1  # step by one so overlapping matches are found
    return positions


# KMP algorithm
def kmp_search(text, pattern):
    """Same result as find_all, in O(len(text) + len(pattern)) time.

    An empty pattern matches at every index 0..len(text), as with str.find.
    """
    if not pattern:
        return list(range(len(text) + 1))  # pattern[0] below would fail

    def build_lps(pattern):
        """lps[i] is the length of the longest proper prefix of
        pattern[:i + 1] that is also a suffix of it."""
        lps = [0] * len(pattern)
        length = 0
        for i in range(1, len(pattern)):
            while length > 0 and pattern[i] != pattern[length]:
                length = lps[length - 1]
            if pattern[i] == pattern[length]:
                length += 1
            lps[i] = length
        return lps

    lps = build_lps(pattern)
    positions = []
    j = 0  # number of pattern characters currently matched
    for i in range(len(text)):
        while j > 0 and text[i] != pattern[j]:
            j = lps[j - 1]  # fall back instead of rescanning text
        if text[i] == pattern[j]:
            j += 1
        if j == len(pattern):
            positions.append(i - j + 1)
            j = lps[j - 1]  # continue so overlapping matches are found
    return positions

字符串处理技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Palindrome check
def is_palindrome(s):
    """Return True if s reads the same forwards and backwards."""
    return s == s[::-1]


# Longest palindromic substring
def longest_palindrome(s):
    """Return the longest palindromic substring of s (expand around centres).

    Each index is tried both as an odd-length centre (i, i) and as an
    even-length centre (i, i + 1). O(n^2).
    """
    def expand(left, right):
        """Grow outward while the ends match; return the palindrome found."""
        while left >= 0 and right < len(s) and s[left] == s[right]:
            left -= 1
            right += 1
        return s[left + 1:right]

    result = ""
    for i in range(len(s)):
        odd = expand(i, i)
        even = expand(i, i + 1)
        if len(odd) > len(result):
            result = odd
        if len(even) > len(result):
            result = even
    return result


# Remove duplicate characters, keeping first occurrences in order
def remove_duplicates(s):
    """Drop repeated characters; dict keys preserve insertion order."""
    return ''.join(dict.fromkeys(s))


# Character frequency
def char_frequency(s):
    """Return a Counter mapping each character of s to its count."""
    from collections import Counter
    return Counter(s)

十、常见题型与技巧

模拟题

模拟题要求按照题目描述逐步实现过程,注意细节和边界条件。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Date simulation
def count_days(start_date, end_date):
    """Absolute number of days between two 'YYYY-MM-DD' date strings."""
    from datetime import datetime
    d1 = datetime.strptime(start_date, "%Y-%m-%d")
    d2 = datetime.strptime(end_date, "%Y-%m-%d")
    return abs((d2 - d1).days)


# Clock simulation
def clock_simulation(initial_time, minutes):
    """Return the 'HH:MM' 24-hour time reached after adding minutes.

    Negative minutes work too: // and % round toward negative infinity, so
    the result always wraps into 00:00..23:59.
    """
    h, m = map(int, initial_time.split(':'))
    total = h * 60 + m + minutes
    new_h = (total // 60) % 24
    new_m = total % 60
    return f"{new_h:02d}:{new_m:02d}"

枚举与暴力

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Enumerate every subset
def all_subsets(arr):
    """Return all 2^len(arr) subsets; bit j of mask i selects arr[j]."""
    result = []
    for i in range(1 << len(arr)):
        subset = [arr[j] for j in range(len(arr)) if i & (1 << j)]
        result.append(subset)
    return result


# Enumerate every interval
def find_max_subarray(arr):
    """Maximum sum of a non-empty contiguous subarray by brute force, O(n^2).

    Returns float('-inf') for empty input.
    """
    n = len(arr)
    max_sum = float('-inf')
    for i in range(n):
        current_sum = 0
        for j in range(i, n):
            current_sum += arr[j]  # running sum of arr[i..j]
            max_sum = max(max_sum, current_sum)
    return max_sum

双指针

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# Two Sum with two pointers
def two_sum(nums, target):
    """Find two distinct positions whose values add up to target.

    Returns [i, j] (i < j) as indices into the ORIGINAL nums, or [-1, -1]
    if there is no such pair. nums is left unmodified: positions are sorted
    by value instead of sorting the caller's list in place.
    """
    order = sorted(range(len(nums)), key=nums.__getitem__)
    left, right = 0, len(order) - 1
    while left < right:
        s = nums[order[left]] + nums[order[right]]
        if s == target:
            return sorted((order[left], order[right]))
        elif s < target:
            left += 1   # need a larger sum
        else:
            right -= 1  # need a smaller sum
    return [-1, -1]


# Sliding window
def max_subarray_sum(nums, k):
    """Maximum sum of any k consecutive elements (fixed-size sliding window).

    Assumes 1 <= k <= len(nums).
    """
    window_sum = sum(nums[:k])
    max_sum = window_sum
    for i in range(k, len(nums)):
        # Slide right by one: add the new element, drop the oldest one.
        window_sum = window_sum - nums[i - k] + nums[i]
        max_sum = max(max_sum, window_sum)
    return max_sum

前缀和与差分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 1-D prefix sums
def prefix_sum(arr):
    """Return prefix with prefix[i] = sum(arr[:i]); length len(arr) + 1."""
    n = len(arr)
    prefix = [0] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] + arr[i]
    return prefix


def range_sum(prefix, l, r):
    """Sum of arr[l..r] (0-based, inclusive) in O(1)."""
    return prefix[r + 1] - prefix[l]


# 2-D prefix sums
def prefix_sum_2d(matrix):
    """Return P with P[i][j] = sum of matrix[:i][:j]; size (m+1) x (n+1).

    An empty matrix gives [[0]].
    """
    m = len(matrix)
    n = len(matrix[0]) if m else 0
    prefix = [[0] * (n + 1) for _ in range(m + 1)]
    for i in range(m):
        for j in range(n):
            # Inclusion-exclusion: top + left - overlap + current cell.
            prefix[i + 1][j + 1] = (prefix[i][j + 1] + prefix[i + 1][j]
                                    - prefix[i][j] + matrix[i][j])
    return prefix


def range_sum_2d(prefix, r1, c1, r2, c2):
    """Sum of the sub-rectangle rows r1..r2, cols c1..c2 (inclusive) in O(1)."""
    return prefix[r2 + 1][c2 + 1] - prefix[r1][c2 + 1] - prefix[r2 + 1][c1] + prefix[r1][c1]

十一、调试技巧与注意事项

常见错误

  1. 数组越界:注意索引从 0 开始,检查边界条件
  2. 整数溢出:Python 整数不会溢出,但超大整数的运算会变慢并占用大量内存;题目要求取模时,应在中间步骤及时取模
  3. 浮点精度:使用整数运算或 Decimal 模块
  4. 递归深度:Python 默认递归深度上限约为 1000,深递归前需设置 sys.setrecursionlimit(10**6);递归过深仍可能导致程序崩溃,必要时改写为循环或显式栈
  5. 时间复杂度:注意数据范围,选择合适算法

调试方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Debug printing
def debug(*args):
    """Print args with a [DEBUG] tag to standard error.

    stderr keeps debug output out of stdout, which is what the judge compares
    against the expected answer.
    """
    import sys
    print(f"[DEBUG] {args}", file=sys.stderr)


# Boundary testing
def test_boundary(func=None, cases=(0, 1, -1, 10**9)):
    """Call func on typical boundary inputs, printing and returning results.

    func is the function under test. With func=None there is nothing to
    exercise and an empty list is returned.
    """
    if func is None:
        return []
    results = []
    for x in cases:
        result = func(x)
        print(result)
        results.append(result)
    return results


# Performance testing
import time


def time_test(func, *args):
    """Run func(*args), print the elapsed time, and return func's result."""
    start = time.perf_counter()  # monotonic, high-resolution timer
    result = func(*args)
    end = time.perf_counter()
    print(f"Time: {end - start:.4f}s")
    return result

代码模板

1
2
3
4
5
6
7
8
9
10
11
12
13
import sys
from collections import defaultdict, deque


def solve():
    """Template: read n and an array from stdin, then print the answer."""
    # Faster than the builtin input() on large inputs. The trailing newline
    # it keeps is harmless for int() and str.split().
    input = sys.stdin.readline
    n = int(input())
    arr = list(map(int, input().split()))

    result = 0  # replace with the real computation

    print(result)


if __name__ == "__main__":
    solve()

十二、备考资源

推荐练习平台

  1. 蓝桥杯官网:历年真题练习
  2. 洛谷:题目分类清晰,适合刷题
  3. 力扣:算法题目丰富,有详细题解
  4. Codeforces:提高编程能力

学习建议

  1. 基础阶段:熟练掌握 Python 语法和常用库
  2. 进阶阶段:学习常见算法和数据结构
  3. 冲刺阶段:大量刷真题,总结题型模板
  4. 模拟练习:限时完成模拟题,适应考试节奏

十三、真题解析与实战

结果填空题技巧

结果填空题只需输出答案,可以使用暴力枚举、数学推导、程序计算等方法。

例题:数字排列

题目:用 0-9 十个数字各一次,组成一个四位数和一个六位数,使得四位数是六位数的约数,共有多少种不同的组合?

解题思路

1
2
3
4
5
6
7
8
9
10
11
from itertools import permutations


def count_divisible_splits(digits='0123456789', split=4):
    """Count digit arrangements where the first number divides the second.

    Every digit is used once. The first `split` digits form number A and the
    rest form number B; neither may start with 0. Counts arrangements where
    A divides B. With the defaults: a 4-digit divisor of a 6-digit number.
    """
    count = 0
    for p in permutations(digits):
        # A leading zero would make a number shorter than required.
        if p[0] == '0' or p[split] == '0':
            continue
        four = int(''.join(p[:split]))
        six = int(''.join(p[split:]))
        if six % four == 0:
            count += 1
    return count


if __name__ == "__main__":
    # Brute force over 10! = 3,628,800 permutations: a few seconds in Python,
    # which is fine for a fill-in-the-answer problem.
    print(count_divisible_splits())

例题:日期问题

题目:小明 1949 年 10 月 1 日出生,问 2024 年 10 月 1 日是他出生的第多少天?

解题思路

1
2
3
4
5
6
from datetime import datetime, timedelta

start = datetime(1949, 10, 1)
end = datetime(2024, 10, 1)
# (end - start).days counts the days that have elapsed. The day of birth is
# day 1, so the ordinal day number ("第多少天") is one more than that.
days = (end - start).days + 1
print(days)

例题:等差素数列

题目:类似 7, 37, 67, 97, 127, 157 这样完全由素数组成的等差数列,叫等差素数列(该例公差为 30,长度为 6)。公差不限,要求数列长度尽可能长。在 10000 以内的素数中,最长的等差素数列长度是多少?

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def is_prime(n):
    """Trial-division primality test."""
    if n < 2:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True


def longest_prime_ap(limit=10000, window=500):
    """Length of the longest arithmetic progression of primes below limit.

    Each pair (primes[i], primes[j]) fixes a first term and a common
    difference. The sequence is then extended while its terms stay prime.
    Only j < i + window is tried, to bound the search. That is a heuristic:
    progressions whose second term lies further away are not considered.
    """
    primes = [i for i in range(2, limit) if is_prime(i)]
    prime_set = set(primes)  # O(1) membership checks while extending

    max_len = 0
    for i, start in enumerate(primes):
        for j in range(i + 1, min(i + window, len(primes))):
            diff = primes[j] - start
            length = 2
            current = primes[j] + diff
            while current in prime_set:
                length += 1
                current += diff
            max_len = max(max_len, length)
    return max_len


if __name__ == "__main__":
    print(longest_prime_ap())

编程大题技巧

例题:区间第 K 小

题目:给定 n 个数的序列,m 次询问,每次询问区间 [l, r] 中第 k 小的数。

解题思路

1
2
3
4
5
6
7
def kth_smallest(arr, l, r, k):
    """Return the k-th smallest value of arr[l..r] (l, r, k are 1-based).

    Brute force: sorts the slice, O(len log len) per query, which is fine
    for small inputs.
    """
    return sorted(arr[l - 1:r])[k - 1]


def main():
    """Read n, m, the array and m queries (l r k); print one answer each."""
    n, m = map(int, input().split())
    arr = list(map(int, input().split()))

    for _ in range(m):
        l, r, k = map(int, input().split())
        print(kth_smallest(arr, l, r, k))


if __name__ == "__main__":
    main()

例题:迷宫最短路径

题目:给定 n×m 的迷宫,0 表示可走,1 表示障碍,求从起点到终点的最短路径长度。

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from collections import deque


def maze_shortest_path(maze, sx, sy, ex, ey):
    """BFS: fewest 4-directional steps from (sx, sy) to (ex, ey), or -1.

    maze cells are 0 (open) or 1 (wall). Coordinates are used directly as
    list indices, i.e. 0-based. NOTE(review): confirm that the problem's
    input is 0-based as well.
    """
    n, m = len(maze), len(maze[0])
    visited = [[False] * m for _ in range(n)]
    queue = deque([(sx, sy, 0)])
    visited[sx][sy] = True

    directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]

    while queue:
        x, y, dist = queue.popleft()
        if x == ex and y == ey:
            return dist

        for dx, dy in directions:
            nx, ny = x + dx, y + dy
            if 0 <= nx < n and 0 <= ny < m and not visited[nx][ny] and maze[nx][ny] == 0:
                visited[nx][ny] = True
                queue.append((nx, ny, dist + 1))

    return -1


def solve():
    """Read the grid, start and end from stdin and print the answer."""
    n, m = map(int, input().split())
    maze = [list(map(int, input().split())) for _ in range(n)]
    sx, sy = map(int, input().split())
    ex, ey = map(int, input().split())
    print(maze_shortest_path(maze, sx, sy, ex, ey))


if __name__ == "__main__":
    solve()

例题:数字三角形

题目:给定一个数字三角形,从顶部出发,每次只能移动到下方相邻的两个数字,求从顶到底的最大路径和。

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def max_path_sum(triangle):
    """Maximum top-to-bottom path sum in a number triangle.

    From row i, column j you may move to (i+1, j) or (i+1, j+1). dp[i][j] is
    the best sum of a path ending at (i, j). Returns 0 for an empty triangle.
    """
    n = len(triangle)
    if n == 0:
        return 0
    dp = [[0] * n for _ in range(n)]
    dp[0][0] = triangle[0][0]

    for i in range(1, n):
        for j in range(i + 1):
            if j == 0:
                dp[i][j] = dp[i-1][j] + triangle[i][j]      # only from above
            elif j == i:
                dp[i][j] = dp[i-1][j-1] + triangle[i][j]    # only from upper-left
            else:
                dp[i][j] = max(dp[i-1][j-1], dp[i-1][j]) + triangle[i][j]

    return max(dp[n-1])


def main():
    """Read n and the n rows of the triangle, then print the answer."""
    n = int(input())
    triangle = [list(map(int, input().split())) for _ in range(n)]
    print(max_path_sum(triangle))


if __name__ == "__main__":
    main()

十四、进阶数据结构

线段树

线段树用于区间查询和单点更新,时间复杂度 O(log n)。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class SegmentTree:
    """Sum segment tree: point assignment and inclusive range-sum queries.

    Indices are 0-based; node 1 is the root and node k has children 2k and
    2k+1. Each update or query costs O(log n).
    """

    def __init__(self, data):
        self.n = len(data)
        # 4n slots always suffice for this layout; max() keeps the list
        # non-empty when data is empty.
        self.tree = [0] * (4 * max(self.n, 1))
        if self.n:  # building over an empty range would recurse forever
            self._build(data, 1, 0, self.n - 1)

    def _build(self, data, node, start, end):
        """Fill node, which covers data[start..end], and its subtree."""
        if start == end:
            self.tree[node] = data[start]
        else:
            mid = (start + end) // 2
            self._build(data, node * 2, start, mid)
            self._build(data, node * 2 + 1, mid + 1, end)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

    def update(self, idx, val):
        """Set element idx to val (assignment, not increment).

        Raises IndexError for idx outside 0..n-1. Without this check an
        out-of-range idx would silently overwrite the first or last leaf.
        """
        if not 0 <= idx < self.n:
            raise IndexError("SegmentTree index out of range")
        self._update(1, 0, self.n - 1, idx, val)

    def _update(self, node, start, end, idx, val):
        """Recursive helper for update(); re-sums nodes on the way back up."""
        if start == end:
            self.tree[node] = val
        else:
            mid = (start + end) // 2
            if idx <= mid:
                self._update(node * 2, start, mid, idx, val)
            else:
                self._update(node * 2 + 1, mid + 1, end, idx, val)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

    def query(self, l, r):
        """Return sum(data[l..r]) (inclusive); 0 for an empty tree."""
        if self.n == 0:
            return 0
        return self._query(1, 0, self.n - 1, l, r)

    def _query(self, node, start, end, l, r):
        """Sum of the part of [l, r] that lies inside [start, end]."""
        if r < start or end < l:
            return 0  # no overlap
        if l <= start and end <= r:
            return self.tree[node]  # fully covered
        mid = (start + end) // 2
        left_sum = self._query(node * 2, start, mid, l, r)
        right_sum = self._query(node * 2 + 1, mid + 1, end, l, r)
        return left_sum + right_sum

树状数组

树状数组(Fenwick Tree)用于前缀和查询和单点更新。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class FenwickTree:
    """Binary indexed tree over positions 1..n: point add and prefix sums.

    tree[i] stores the sum of the block of length (i & -i) that ends at i.
    """

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # index 0 is unused

    def update(self, i, delta):
        """Add delta to position i (1-based).

        Raises IndexError unless 1 <= i <= n. For i <= 0 the loop below
        would never finish, because i & -i stops changing i once i is 0.
        """
        if not 1 <= i <= self.n:
            raise IndexError("FenwickTree index out of range")
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)  # jump to the next block that covers position i

    def query(self, i):
        """Return the sum of positions 1..i; i is clamped to n, <= 0 gives 0."""
        i = min(i, self.n)
        result = 0
        while i > 0:
            result += self.tree[i]
            i -= i & (-i)  # drop the lowest set bit
        return result

    def range_query(self, l, r):
        """Return the sum of positions l..r (1-based, inclusive)."""
        return self.query(r) - self.query(l - 1)

单调栈

单调栈用于解决下一个更大元素等问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def next_greater_element(nums):
    """For each index, the first value to its right that is strictly larger.

    Returns -1 where no such value exists. The stack holds indices whose
    answer is still unknown; their values never increase from bottom to top.
    """
    n = len(nums)
    result = [-1] * n
    stack = []

    for i in range(n):
        while stack and nums[stack[-1]] < nums[i]:
            result[stack.pop()] = nums[i]
        stack.append(i)

    return result


def largest_rectangle_area(heights):
    """Area of the largest rectangle in a histogram, O(n).

    A trailing 0 bar forces every remaining bar to be popped at the end. The
    -1 sentinel at the bottom of the stack marks the left boundary.
    """
    heights = heights + [0]
    stack = [-1]
    max_area = 0

    for i, h in enumerate(heights):
        while stack[-1] != -1 and heights[stack[-1]] > h:
            height = heights[stack.pop()]
            width = i - stack[-1] - 1  # bars strictly between the new top and i
            max_area = max(max_area, height * width)
        stack.append(i)

    return max_area

单调队列

单调队列用于滑动窗口最值问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from collections import deque


def sliding_window_max(nums, k):
    """Maximum of every length-k window of nums, in O(n).

    dq holds indices whose values decrease from front to back, so the front
    is always the current window's maximum. Returns [] when k <= 0.
    """
    if k <= 0:
        return []
    result = []
    dq = deque()

    for i, num in enumerate(nums):
        # Smaller values behind a newer, larger one can never be a maximum.
        while dq and nums[dq[-1]] < num:
            dq.pop()
        dq.append(i)

        if dq[0] <= i - k:  # front index has slid out of the window
            dq.popleft()

        if i >= k - 1:
            result.append(nums[dq[0]])

    return result

字典树(Trie)

字典树用于字符串前缀匹配。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Trie:
    """Prefix tree: every node maps a character to a child Trie node."""

    def __init__(self):
        self.children = {}
        self.is_end = False  # True if a stored word ends at this node

    def insert(self, word):
        """Add word, creating nodes along its path as needed."""
        node = self
        for ch in word:
            if ch not in node.children:
                node.children[ch] = Trie()
            node = node.children[ch]
        node.is_end = True

    def search(self, word):
        """Return True only if word itself was inserted."""
        node = self
        for ch in word:
            if ch not in node.children:
                return False
            node = node.children[ch]
        return node.is_end

    def starts_with(self, prefix):
        """Return True if some inserted word begins with prefix."""
        node = self
        for ch in prefix:
            if ch not in node.children:
                return False
            node = node.children[ch]
        return True

十五、高级算法技巧

状态压缩 DP

状态压缩用于处理状态集合较小的问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def traveling_salesman(cost):
    """Cheapest tour that starts at city 0, visits every city once, and returns.

    Held-Karp bitmask DP, O(2^n * n^2): dp[mask][u] is the cheapest path that
    starts at 0, visits exactly the cities in mask, and ends at u. Returns 0
    for n <= 1, where there is nowhere to travel.
    """
    n = len(cost)
    if n <= 1:
        return 0
    inf = float('inf')
    dp = [[inf] * n for _ in range(1 << n)]
    dp[1][0] = 0

    for mask in range(1 << n):
        if not mask & 1:
            continue  # every valid path contains the start city 0
        for u in range(n):
            cur = dp[mask][u]
            if cur == inf or not (mask & (1 << u)):
                continue  # unreachable state
            for v in range(n):
                if mask & (1 << v):
                    continue
                new_mask = mask | (1 << v)
                if cur + cost[u][v] < dp[new_mask][v]:
                    dp[new_mask][v] = cur + cost[u][v]

    full_mask = (1 << n) - 1
    return min(dp[full_mask][i] + cost[i][0] for i in range(1, n))


def count_subsets_with_sum(nums, target):
    """Count subsets (by position) whose elements sum to target; O(2^n * n)."""
    n = len(nums)
    count = 0

    for mask in range(1 << n):
        total = sum(nums[i] for i in range(n) if mask & (1 << i))
        if total == target:
            count += 1

    return count

记忆化搜索

记忆化搜索是自顶向下的动态规划。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from functools import lru_cache


def climb_stairs_memo(n):
    """Ways to climb n steps using 1- or 2-steps (top-down, memoised).

    dfs(i) = dfs(i-1) + dfs(i-2), with dfs(0) = dfs(1) = 1.
    """
    @lru_cache(maxsize=None)
    def dfs(i):
        if i <= 1:
            return 1
        return dfs(i - 1) + dfs(i - 2)
    return dfs(n)


def min_path_sum_memo(grid):
    """Minimum path sum from the top-left to the bottom-right cell.

    Moves go only right or down. dfs(i, j) is the best sum of a path ending
    at (i, j); off-grid cells cost infinity. Returns 0 for an empty grid.
    """
    if not grid or not grid[0]:
        return 0
    m, n = len(grid), len(grid[0])

    @lru_cache(maxsize=None)
    def dfs(i, j):
        if i == 0 and j == 0:
            return grid[0][0]
        if i < 0 or j < 0:
            return float('inf')
        return grid[i][j] + min(dfs(i - 1, j), dfs(i, j - 1))

    return dfs(m - 1, n - 1)

贪心算法

贪心算法在每一步选择局部最优解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def activity_selection(starts, ends):
    """Maximum number of non-overlapping activities.

    Greedy: sort by end time and take each activity that starts no earlier
    than the last chosen one ends.
    """
    activities = sorted(zip(starts, ends), key=lambda x: x[1])
    count = 0
    last_end = -float('inf')

    for s, e in activities:
        if s >= last_end:
            count += 1
            last_end = e

    return count


def fractional_knapsack(weights, values, capacity):
    """Best value when items may be taken fractionally.

    Greedy by value density (value / weight), highest first. Zero-weight
    items have infinite density and are always taken whole. Without that
    special case they would cause a ZeroDivisionError in the sort key.
    """
    items = sorted(zip(weights, values),
                   key=lambda x: x[1] / x[0] if x[0] else float('inf'),
                   reverse=True)
    total_value = 0

    for w, v in items:
        if capacity >= w:
            total_value += v
            capacity -= w
        else:
            total_value += v * (capacity / w)  # take the fraction that fits
            break

    return total_value


def jump_game(nums):
    """Return True if the last index is reachable from index 0.

    nums[i] is the maximum jump length from position i.
    """
    max_reach = 0
    for i, num in enumerate(nums):
        if i > max_reach:
            return False  # position i can never be reached
        max_reach = max(max_reach, i + num)
    return True

分治算法

分治算法将问题分解为子问题递归求解。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def _merge_count(left, right):
    """Merge two ascending lists; also count pairs (x in left, y in right) with x > y.

    Returns (merged_list, inversion_count). Taking from left on ties keeps
    the merge stable.
    """
    result = []
    count = 0
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            # right[j] is smaller than every remaining element of left.
            count += len(left) - i
            j += 1
    result.extend(left[i:])
    result.extend(right[j:])
    return result, count


def merge_sort(arr):
    """Return a stable ascending sort of arr (top-down merge sort, O(n log n)).

    Inputs of length <= 1 are returned as-is.
    """
    if len(arr) <= 1:
        return arr

    mid = len(arr) // 2
    merged, _ = _merge_count(merge_sort(arr[:mid]), merge_sort(arr[mid:]))
    return merged


def count_inversions(arr):
    """Number of pairs i < j with arr[i] > arr[j], counted during a merge sort."""
    def sort_count(arr):
        """Return (sorted copy, inversion count) of arr."""
        if len(arr) <= 1:
            return arr, 0
        mid = len(arr) // 2
        left, lc = sort_count(arr[:mid])
        right, rc = sort_count(arr[mid:])
        merged, mc = _merge_count(left, right)
        return merged, lc + rc + mc

    _, count = sort_count(arr)
    return count

二分答案

二分答案用于求解满足条件的最值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def find_min_max_distance(positions, m):
    """Largest possible minimum gap when placing m items at the given positions.

    Binary search on the answer. Feasibility is monotone: if gap d works, any
    smaller gap works too, so we search for the largest feasible d. A sorted
    copy is used, so the caller's list is not reordered.
    """
    positions = sorted(positions)

    def can_place(min_dist):
        """Place greedily from the left; True if at least m items fit."""
        count = 1
        last = positions[0]
        for pos in positions[1:]:
            if pos - last >= min_dist:
                count += 1
                last = pos
        return count >= m

    left, right = 1, positions[-1] - positions[0]
    while left < right:
        # Round mid up so that `left = mid` always makes progress.
        mid = (left + right + 1) // 2
        if can_place(mid):
            left = mid
        else:
            right = mid - 1

    return left


def minimize_max_value(nums, k):
    """Binary search for the smallest max_val that can_achieve accepts.

    NOTE(review): can_achieve compares the sum of ceil((num - max_val) / k)
    against the sum of (num - max_val) over the same elements. For k >= 1
    the left side never exceeds the right, so every candidate passes and the
    result is always min(nums); k == 0 raises ZeroDivisionError. The intended
    constraint (e.g. an operation budget) is not visible here. Confirm it
    against the problem statement before relying on this function.
    """
    def can_achieve(max_val):
        operations = 0
        for num in nums:
            if num > max_val:
                # presumably: operations needed if each lowers num by at most k
                operations += (num - max_val + k - 1) // k
        return operations <= sum(n - max_val for n in nums if n > max_val)

    left, right = min(nums), max(nums)
    while left < right:
        mid = (left + right) // 2
        if can_achieve(mid):
            right = mid
        else:
            left = mid + 1

    return left

十六、图论进阶

拓扑排序

拓扑排序用于有向无环图的排序。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from collections import deque


def topological_sort(n, edges):
    """Kahn's algorithm on vertices 0..n-1 with directed edges (u, v).

    Returns a topological order, or [] if the graph has a cycle (some
    vertices never reach in-degree 0).
    """
    graph = [[] for _ in range(n)]
    in_degree = [0] * n

    for u, v in edges:
        graph[u].append(v)
        in_degree[v] += 1

    queue = deque([i for i in range(n) if in_degree[i] == 0])
    result = []

    while queue:
        node = queue.popleft()
        result.append(node)
        for neighbor in graph[node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)

    return result if len(result) == n else []

最小生成树

Kruskal 算法

1
2
3
4
5
6
7
8
9
10
11
12
def kruskal(n, edges):
    """Minimum spanning tree (or forest) via Kruskal's algorithm.

    edges is an iterable of (u, v, w) with vertices in 0..n-1. Uses the
    UnionFind class; its union() returns True when it joins two different
    components. Edges are processed in a sorted copy, so the caller's list
    is not reordered. Returns (total_weight, mst_edges).
    """
    uf = UnionFind(n)
    total_weight = 0
    mst_edges = []

    for u, v, w in sorted(edges, key=lambda x: x[2]):
        if uf.union(u, v):  # skip edges that would close a cycle
            total_weight += w
            mst_edges.append((u, v, w))

    return total_weight, mst_edges

Prim 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import heapq


def prim(n, graph):
    """Total weight of a minimum spanning tree via lazy Prim from vertex 0.

    graph[u] is a list of (v, w) pairs; list each undirected edge in both
    directions. If the graph is disconnected, only vertex 0's component is
    counted. Returns 0 for n == 0.
    """
    if n == 0:
        return 0
    visited = [False] * n
    min_heap = [(0, 0)]  # (edge weight into the tree, vertex)
    total_weight = 0

    while min_heap:
        weight, node = heapq.heappop(min_heap)
        if visited[node]:
            continue  # already joined via a cheaper edge
        visited[node] = True
        total_weight += weight

        for neighbor, w in graph[node]:
            if not visited[neighbor]:
                heapq.heappush(min_heap, (w, neighbor))

    return total_weight

强连通分量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def kosaraju(n, graph):
    """Strongly connected components of a directed graph (Kosaraju).

    graph[v] lists the successors of vertex v (0..n-1).
    Pass 1 pushes vertices in order of DFS finish time. Pass 2 explores the
    reversed graph in reverse finish order, and each tree it grows is one
    SCC. Returns a list of components (lists of vertices). Both passes are
    recursive, so very deep graphs may need a higher recursion limit.
    """
    def dfs1(v, visited, stack):
        """Record v after all of its descendants (finish order)."""
        visited[v] = True
        for u in graph[v]:
            if not visited[u]:
                dfs1(u, visited, stack)
        stack.append(v)

    def dfs2(v, visited, component, reverse_graph):
        """Collect every vertex reachable from v in the reversed graph."""
        visited[v] = True
        component.append(v)
        for u in reverse_graph[v]:
            if not visited[u]:
                dfs2(u, visited, component, reverse_graph)

    visited = [False] * n
    stack = []
    for i in range(n):
        if not visited[i]:
            dfs1(i, visited, stack)

    reverse_graph = [[] for _ in range(n)]
    for v in range(n):
        for u in graph[v]:
            reverse_graph[u].append(v)

    visited = [False] * n
    sccs = []

    while stack:
        v = stack.pop()  # latest finish time first
        if not visited[v]:
            component = []
            dfs2(v, visited, component, reverse_graph)
            sccs.append(component)

    return sccs

十七、数论进阶

欧拉函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def euler_phi(n):
    """Euler's totient: how many integers in 1..n are coprime to n.

    Trial-factorises n and applies result *= (1 - 1/p) for each distinct
    prime p, using integer arithmetic (result -= result // p).
    """
    result = n
    p = 2
    while p * p <= n:
        if n % p == 0:
            while n % p == 0:
                n //= p
            result -= result // p
        p += 1
    if n > 1:  # one prime factor larger than sqrt(original n) remains
        result -= result // n
    return result


def euler_phi_sieve(n):
    """Return phi[0..n] for every value at once, with a sieve.

    phi[i] == i at the moment i is reached means no smaller prime has
    touched it, so i is prime.
    """
    phi = list(range(n + 1))
    for i in range(2, n + 1):
        if phi[i] == i:
            for j in range(i, n + 1, i):
                phi[j] -= phi[j] // i
    return phi

模逆元

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def mod_inverse(a, mod):
    """Modular inverse of a modulo mod via the extended Euclidean algorithm.

    Returns x in [0, mod) with a * x % mod == 1, or None if gcd(a, mod) != 1.
    """
    def extended_gcd(a, b):
        """Return (g, x, y) with a*x + b*y == g, where g is gcd(a, b)."""
        if b == 0:
            return a, 1, 0
        g, x, y = extended_gcd(b, a % b)
        return g, y, x - (a // b) * y

    g, x, _ = extended_gcd(a, mod)
    if g != 1:
        return None  # no inverse exists
    return x % mod


def mod_inverse_fermat(a, mod):
    """Modular inverse by Fermat's little theorem: a^(mod-2) mod mod.

    Valid only when mod is prime and a is not a multiple of mod; otherwise
    the value returned is not an inverse.
    """
    return pow(a, mod - 2, mod)

中国剩余定理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def chinese_remainder_theorem(remainders, moduli):
    """Smallest non-negative x with x % moduli[i] == remainders[i] for all i.

    The moduli must be pairwise coprime; otherwise ValueError is raised.
    Without the check, a non-existent inverse would silently produce a wrong
    answer. x = sum(r_i * M_i * inv(M_i mod m_i)) mod M, where M is the
    product of the moduli and M_i = M / m_i.
    """
    def extended_gcd(a, b):
        """Return (g, x, y) with a*x + b*y == g, where g is gcd(a, b)."""
        if b == 0:
            return a, 1, 0
        g, x, y = extended_gcd(b, a % b)
        return g, y, x - (a // b) * y

    M = 1
    for m in moduli:
        M *= m

    result = 0
    for r, m in zip(remainders, moduli):
        Mi = M // m
        g, inv, _ = extended_gcd(Mi, m)
        if g != 1:
            # m shares a factor with another modulus, so Mi has no inverse.
            raise ValueError("moduli must be pairwise coprime")
        result += r * Mi * inv

    return result % M

矩阵快速幂

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def matrix_multiply(A, B, mod=None):
    """Return the product of A (n x k) and B (k x m) as lists of lists.

    Each entry is reduced modulo mod when mod is truthy (None or 0 means no
    reduction).
    """
    n = len(A)
    m = len(B[0])
    k = len(B)
    C = [[0] * m for _ in range(n)]

    for i in range(n):
        row_a = A[i]
        row_c = C[i]
        for j in range(m):
            total = 0
            for p in range(k):
                total += row_a[p] * B[p][j]
            # Reduce once per entry rather than after every product. Python
            # ints do not overflow, so the result is identical.
            row_c[j] = total % mod if mod else total

    return C


def matrix_pow(M, n, mod=None):
    """Return M**n for a square matrix M by binary exponentiation.

    Uses O(log n) multiplications; n == 0 gives the identity matrix.
    """
    size = len(M)
    result = [[1 if i == j else 0 for j in range(size)] for i in range(size)]

    while n > 0:
        if n % 2 == 1:
            result = matrix_multiply(result, M, mod)
        M = matrix_multiply(M, M, mod)
        n //= 2

    return result


def fibonacci_matrix(n, mod=None):
    """F(n) in O(log n) using [[1, 1], [1, 0]]^(n-1), whose [0][0] entry is F(n).

    For n <= 1, n is returned unreduced.
    """
    if n <= 1:
        return n
    M = [[1, 1], [1, 0]]
    result = matrix_pow(M, n - 1, mod)
    return result[0][0]

十八、常见易错点详解

整数除法问题

1
2
3
4
5
6
7
8
9
# Pitfall: in Python 3, / always produces a float
a = 7 / 2        # 3.5, not 3

# Fix: // is floor (integer) division
a = 7 // 2       # 3

# Careful with negatives: // rounds down, int() truncates toward zero
print(-7 // 2)       # -4 (floor)
print(int(-7 / 2))   # -3 (toward zero)

列表复制问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Pitfall: a slice is a SHALLOW copy, so the inner lists are still shared
arr = [[1, 2], [3, 4]]
copy_arr = arr[:]
copy_arr[0][0] = 99      # this changes arr[0][0] as well

# Fix: deep copy duplicates the nested lists too
import copy
arr = [[1, 2], [3, 4]]
copy_arr = copy.deepcopy(arr)
copy_arr[0][0] = 99      # arr is untouched

# Pitfall: [row] * 3 repeats ONE row object three times
matrix = [[0] * 3] * 3
matrix[0][0] = 1         # every "row" now starts with 1

# Fix: build a separate row per iteration
matrix = [[0] * 3 for _ in range(3)]

字典遍历修改问题

1
2
3
4
5
6
7
8
9
10
11
# Pitfall: deleting keys while iterating over the dict itself raises
# "RuntimeError: dictionary changed size during iteration". The demo is wrapped so
# the rest of this listing still runs.
d = {'a': 1, 'b': 2, 'c': 3}
try:
    for key in d:
        if d[key] == 2:
            del d[key]
except RuntimeError as exc:
    print("RuntimeError:", exc)

# Fix: first collect the keys to delete, then delete them.
d = {'a': 1, 'b': 2, 'c': 3}
to_delete = [k for k, v in d.items() if v == 2]
for key in to_delete:
    del d[key]

浮点数精度问题

1
2
3
4
5
6
7
8
9
10
11
12
from decimal import Decimal

# Pitfall: binary floating point cannot represent 0.1, 0.2 or 0.3 exactly.
a = 0.1 + 0.2
print(a == 0.3)  # False


def float_equal(a, b, eps=1e-9):
    """Return True when ``a`` and ``b`` differ by less than the absolute tolerance ``eps``."""
    difference = a - b
    return -eps < difference < eps


# Alternative: Decimal values built from strings keep exact decimal digits.
a = Decimal('0.1') + Decimal('0.2')
print(a == Decimal('0.3'))  # True

递归深度问题

1
2
3
4
5
6
7
import sys

# CPython's default recursion limit (1000) is too small for deep recursion on large
# inputs. NOTE: very deep recursion can still overflow the C stack and crash the
# interpreter, so rewriting as a loop is the safer fix when it is practical.
sys.setrecursionlimit(10**6)


def deep_recursion(n):
    """Recurse ``n`` levels deep and return ``n`` (each frame adds one)."""
    return 0 if n == 0 else 1 + deep_recursion(n - 1)

边界条件处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def binary_search(arr, target):
    """Return an index of ``target`` in the ascending list ``arr``, or -1 if absent.

    Searches the closed interval [lo, hi], hence the ``<=`` loop condition. Python
    ints never overflow; the ``lo + (hi - lo) // 2`` form simply mirrors the idiom
    that prevents overflow in C/Java.
    """
    lo = 0
    hi = len(arr) - 1
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        value = arr[mid]
        if value == target:
            return mid
        if value < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return -1

def safe_divide(a, b):
    """Divide ``a`` by ``b`` without raising ZeroDivisionError.

    A nonzero numerator over zero gives +inf or -inf according to the sign of ``a``.
    0 / 0 is undefined and gives NaN (it previously returned +inf).
    """
    if b == 0:
        if a == 0:
            return float('nan')
        return float('inf') if a > 0 else float('-inf')
    return a / b

十九、时间复杂度优化技巧

常见时间复杂度对照表

数据规模可接受复杂度适用算法
n ≤ 10O(n!)全排列
n ≤ 20O(2^n)状态压缩
n ≤ 100O(n^3)Floyd
n ≤ 1000O(n^2)朴素 DP
n ≤ 10^4O(n√n)优化枚举
n ≤ 10^5O(n log n)排序、二分
n ≤ 10^6O(n)线性算法
n ≤ 10^12O(√n)试除法、分解质因数
n ≤ 10^18O(log n)快速幂、二分答案

优化技巧

1. 预处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Precompute factorials modulo a prime. Without the modulus, 100000! alone has about
# 456,000 decimal digits, and storing every intermediate factorial takes gigabytes
# of memory. Contest problems ask for such answers modulo a prime anyway.
MOD = 10**9 + 7
MAX_N = 10**5
factorial = [1] * (MAX_N + 1)
for i in range(1, MAX_N + 1):
    factorial[i] = factorial[i - 1] * i % MOD

# Sieve of Eratosthenes: is_prime[k] says whether k is prime; primes lists them in order.
is_prime = [True] * (MAX_N + 1)
is_prime[0] = is_prime[1] = False  # 0 and 1 are not prime
primes = []
for i in range(2, MAX_N + 1):
    if is_prime[i]:
        primes.append(i)
        for j in range(i * i, MAX_N + 1, i):
            is_prime[j] = False

2. 空间换时间

1
2
3
4
5
6
7
8
# A hash table turns the O(n^2) pair search into a single O(n) pass.
def two_sum_hash(nums, target):
    """Return [i, j] with i < j and nums[i] + nums[j] == target, or [-1, -1].

    ``index_of`` maps each value seen so far to its most recent index.
    """
    index_of = {}
    for j, value in enumerate(nums):
        complement = target - value
        if complement in index_of:
            return [index_of[complement], j]
        index_of[value] = j
    return [-1, -1]

3. 剪枝优化

1
2
3
4
5
6
7
8
9
10
11
def dfs_with_pruning(arr, target, start, current, result, current_sum=None):
    """Append to ``result`` every combination from arr[start:] (reuse allowed) summing to ``target``.

    ``current`` is the combination being built; each completed one is appended as a
    copy. The running sum is tracked explicitly. The previous version compared the
    list ``current`` with the number ``target``, which raises TypeError in Python 3.
    If ``current_sum`` is omitted, it is computed from ``current``.

    Assumes every number in ``arr`` is positive. Otherwise the ``> target`` prune is
    invalid, and a zero element would recurse forever.
    """
    if current_sum is None:
        current_sum = sum(current)
    if current_sum == target:
        result.append(current[:])
        return
    if current_sum > target:
        return  # prune: adding more positive numbers only increases the sum

    for i in range(start, len(arr)):
        current.append(arr[i])
        # Recurse with start=i (not i + 1) so the same element may be reused.
        dfs_with_pruning(arr, target, i, current, result, current_sum + arr[i])
        current.pop()

4. 滚动数组

1
2
3
4
5
6
def knapsack_optimized(weights, values, capacity):
    """0/1 knapsack: best total value within ``capacity`` using a rolling 1-D DP array.

    ``best[c]`` is the best value achievable with total weight at most c. Capacities
    are scanned from high to low so that each item is used at most once.
    """
    best = [0] * (capacity + 1)
    for item_weight, item_value in zip(weights, values):
        c = capacity
        while c >= item_weight:
            candidate = best[c - item_weight] + item_value
            if candidate > best[c]:
                best[c] = candidate
            c -= 1
    return best[capacity]

二十、模拟考试模板

标准答题模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import sys
from collections import defaultdict, deque  # commonly needed in solutions

# readline is faster than the builtin input(). Its trailing newline is harmless for
# int()/split(), but call .strip() when reading raw strings.
input = sys.stdin.readline


def solve():
    """Read n and then a line of integers, and print the answer."""
    n = int(input())
    arr = [int(tok) for tok in input().split()]

    result = 0  # TODO: compute the answer from n and arr

    print(result)


if __name__ == "__main__":
    solve()

多组测试数据模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import sys

input = sys.stdin.readline


def solve():
    """Process one test case; return False to stop (a terminating ``0`` or end of input).

    The original template had two crashes. It printed ``result`` without ever
    assigning it (NameError). At end of input readline() returns '', so ``int('')``
    raised ValueError whenever the data had no terminating 0.
    """
    line = input()
    while line and not line.strip():  # skip blank lines between cases
        line = input()
    if not line:  # end of input
        return False
    n = int(line)
    if n == 0:
        return False

    arr = list(map(int, input().split()))

    result = 0  # TODO: compute the answer for this case from n and arr

    print(result)
    return True


if __name__ == "__main__":
    while solve():
        pass

快速 IO 模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import sys

# Rebind input/print to the raw stream methods for speed. This ``print`` is
# sys.stdout.write: it takes exactly one string and adds no newline.
input = sys.stdin.readline
print = sys.stdout.write


def solve():
    """Read n and a line of integers, then write the answer with an explicit newline."""
    n = int(input())
    arr = list(map(int, input().split()))

    result = 0  # TODO: compute the answer from n and arr

    print(f"{result}\n")


if __name__ == "__main__":
    solve()

二十一、历年真题精选解析

2021 年真题:卡片

题目:小蓝有很多数字卡片,每张卡片上都是数字 0 到 9,每种数字的卡片各有 2021 张。小蓝准备用这些卡片来拼一些数,他想从 1 开始拼出正整数,每拼一个就保存起来,卡片就不能用来拼其它数了。小蓝想知道自己能从 1 拼到多少。

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def solve():
    """Print the largest k such that 1..k can all be spelled with 2021 cards of each digit."""
    remaining = [2021] * 10
    for num in range(1, 100000):
        for ch in str(num):
            d = int(ch)
            if remaining[d] == 0:
                # num cannot be completed, so the last fully spelled number is num - 1.
                print(num - 1)
                return
            remaining[d] -= 1

    print(num)


solve()

2022 年真题:裁纸刀

题目:小蓝有一个裁纸刀,每次可以将一张纸切成两份。现在有 n 张纸,问最少需要切多少刀才能得到 m 张纸?

解题思路

1
2
3
4
5
6
n, m = map(int, input().split())

# Each cut turns one sheet into two, so reaching m sheets from n takes m - n cuts,
# and no cuts at all when we already have at least m sheets.
print(max(0, m - n))

2021 年真题:货物摆放

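题目:小蓝有一个超大的仓库,可以摆放很多货物。现在有 n 箱规则的正方体货物,小蓝要在长、宽、高三个方向上分别堆 L、W、H 箱,把它们摆成一个大长方体,满足 n = L × W × H。L、W、H 的顺序不同视为不同方案(例如 n = 4 时有 1×1×4、1×2×2、1×4×1、2×1×2、2×2×1、4×1×1 共 6 种)。问当 n = 2021041820210418 时,共有多少种方案?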

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def count_divisors(n):
    """Return the number of positive divisors of ``n`` (trial division up to sqrt(n))."""
    total = 0
    d = 1
    while d * d <= n:
        if n % d == 0:
            # d pairs with n // d, except when they coincide (n is a perfect square).
            total += 1 if d * d == n else 2
        d += 1
    return total

# Count ordered triples (L, W, H) with L * W * H == n.
n = 2021041820210418


def _prime_factorization(x):
    """Return {prime: exponent} for x >= 1 by trial division, shrinking x as factors are found."""
    exponents = {}
    p = 2
    while p * p <= x:
        while x % p == 0:
            exponents[p] = exponents.get(p, 0) + 1
            x //= p
        p += 1
    if x > 1:
        exponents[x] = exponents.get(x, 0) + 1
    return exponents


# Build every divisor from the prime factorization. Trial division up to sqrt(n)
# takes about 4.5e7 iterations here, which is very slow in Python. Dividing out each
# prime as it is found lets the search stop much earlier when n has small prime
# factors, as it does here.
factors = [1]
for prime, exp in _prime_factorization(n).items():
    factors = [d * prime ** e for d in factors for e in range(exp + 1)]
factors.sort()

count = 0
for a in factors:
    for b in factors:
        # When a * b divides n, c = n // (a * b) is automatically a divisor too.
        if n % (a * b) == 0:
            count += 1

print(count)

2020 年真题:蛇形填数

题目:如下图所示,小明用从 1 开始的正整数“蛇形”填充无限大的矩阵。容易看出矩阵第 2 行第 2 列中的数是 5。请你计算矩阵中第 20 行第 20 列的数是多少?

1
2
3
4
5
1  2  6  7  15 ...
3 5 8 14 ...
4 9 13 ...
10 12 ...
11 ...

解题思路

1
2
3
4
5
6
7
8
9
def get_value(row, col):
    """Return the number at 1-based (row, col) of the snake-filled matrix.

    Anti-diagonal d = row + col - 1 holds consecutive numbers starting at
    d*(d-1)/2 + 1. Odd diagonals run from bottom-left to top-right, so they are
    numbered by column; even diagonals run the other way and are numbered by row.
    """
    diag = row + col - 1
    first = diag * (diag - 1) // 2 + 1
    offset = col if diag % 2 == 1 else row
    return first + offset - 1


print(get_value(20, 20))

2019 年真题:数的分解

题目:把 2019 分解成 3 个各不相同的正整数之和,并且要求每个正整数都不包含数字 2 和 4,一共有多少种不同的分解方法?

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def has_invalid_digit(n):
while n > 0:
digit = n % 10
if digit == 2 or digit == 4:
return True
n //= 10
return False

count = 0
for a in range(1, 2020):
if has_invalid_digit(a):
continue
for b in range(a + 1, 2020):
if has_invalid_digit(b):
continue
c = 2019 - a - b
if c <= b or has_invalid_digit(c):
continue
count += 1

print(count)

二十二、高级 DP 专题

区间 DP

区间 DP 用于处理区间上的最优问题,通常枚举区间长度和分割点。

石子合并

题目:有 n 堆石子排成一排,每次可以将相邻的两堆合并,代价为两堆石子数之和。求将所有石子合并成一堆的最小代价。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def stone_merge(stones):
    """Minimum total cost to merge adjacent piles into one (merging costs the sum of the two piles).

    Interval DP: dp[i][j] is the cheapest way to merge piles i..j into one. The last
    merge of that interval always costs sum(stones[i..j]), read from prefix sums.
    Returns 0 for zero or one pile (empty input previously raised IndexError).
    """
    n = len(stones)
    if n == 0:
        return 0
    prefix = [0] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] + stones[i]

    dp = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            interval_sum = prefix[j + 1] - prefix[i]  # the same for every split point
            dp[i][j] = min(dp[i][k] + dp[k + 1][j] for k in range(i, j)) + interval_sum

    return dp[0][n - 1]

矩阵链乘法

1
2
3
4
5
6
7
8
9
10
11
12
13
def matrix_chain_order(dims):
    """Minimum scalar multiplications to compute a chain of matrices.

    Matrix t (0-based) has shape dims[t] × dims[t + 1]. dp[i][j] is the cheapest cost
    for matrices i..j; splitting after k costs dims[i] * dims[k + 1] * dims[j + 1].
    Returns 0 when there are fewer than two matrices (len(dims) <= 1 previously
    raised IndexError).
    """
    n = len(dims) - 1
    if n <= 0:
        return 0
    dp = [[0] * n for _ in range(n)]

    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            for k in range(i, j):
                cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                dp[i][j] = min(dp[i][j], cost)

    return dp[0][n - 1]

树形 DP

树形 DP 用于处理树结构上的动态规划问题。

树的最大独立集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def max_independent_set(tree, n):
    """Return the size of a maximum independent set of a tree rooted at node 0.

    ``tree`` is an adjacency list; ``n`` is accepted for interface compatibility but
    not used. For every node the DFS returns (best size if the node is taken, best
    size if it is skipped). Recursion depth equals the tree height.
    """
    def best(node, parent):
        take, skip = 1, 0
        for child in tree[node]:
            if child == parent:
                continue
            child_take, child_skip = best(child, node)
            take += child_skip
            skip += max(child_take, child_skip)
        return take, skip

    return max(best(0, -1))

树的直径

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def tree_diameter(tree, n):
    """Return the weighted diameter of a tree given as adjacency lists of (neighbor, weight).

    The tree is rooted at node 0; ``n`` is not used. At every node, the two deepest
    downward paths through different children form the longest path bending there.
    """
    longest = [0]

    def deepest(node, parent):
        first = second = 0
        for child, weight in tree[node]:
            if child == parent:
                continue
            depth = deepest(child, node) + weight
            if depth > first:
                first, second = depth, first
            elif depth > second:
                second = depth
        longest[0] = max(longest[0], first + second)
        return first

    deepest(0, -1)
    return longest[0]

数位 DP

数位 DP 用于统计满足特定条件的数字个数。

不含连续 1 的数字

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def count_without_consecutive_ones(n):
    """Count integers in [0, n] whose decimal form has no two adjacent '1' digits.

    Digit DP over the decimal string of n. The state is the position, whether the
    prefix still equals n's prefix (``tight``), and whether the previous digit was 1.
    Padding shorter numbers with leading zeros never creates "11", so it is safe.
    """
    from functools import lru_cache

    digits = [int(ch) for ch in str(n)]
    size = len(digits)

    @lru_cache(None)
    def walk(pos, tight, prev_is_one):
        if pos == size:
            return 1
        upper = digits[pos] if tight else 9
        total = 0
        for d in range(upper + 1):
            if prev_is_one and d == 1:
                continue
            total += walk(pos + 1, tight and d == upper, d == 1)
        return total

    return walk(0, True, False)

数字中某数位出现的次数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def count_digit_occurrences(n, target):
    """Return how many times digit ``target`` appears when writing 0, 1, ..., n in decimal.

    Digit DP with a ``started`` flag so that the padding zeros of shorter numbers are
    not counted. Previously target == 0 was over-counted (e.g. 12 instead of 2 for
    n = 10). The number 0 itself is written "0" and contributes one zero.
    """
    from functools import lru_cache

    s = str(n)
    length = len(s)

    @lru_cache(None)
    def dfs(pos, tight, started, count):
        if pos == length:
            if not started:
                return 1 if target == 0 else 0  # the number 0
            return count

        limit = int(s[pos]) if tight else 9
        total = 0
        for digit in range(limit + 1):
            new_started = started or digit != 0
            new_count = count + (1 if new_started and digit == target else 0)
            total += dfs(pos + 1, tight and digit == limit, new_started, new_count)
        return total

    return dfs(0, True, False, 0)

概率 DP

骰子期望

1
2
3
4
5
6
7
8
9
10
def dice_expectation(n):
    """Expected number of rolls of a fair six-sided die until the running total reaches at least ``n``.

    E[0] = 0 and E[i] = 1 + (1/6) * sum_{j=1..6} E[max(i - j, 0)]: every roll costs
    one, and a roll that reaches or passes the target ends the process. The previous
    version skipped faces j > i entirely, dropping their "+1", so it gave E[1] = 1/6
    instead of 1.
    """
    dp = [0.0] * (n + 1)
    for i in range(1, n + 1):
        total = 0.0
        for face in range(1, 7):
            total += dp[i - face] if i > face else 0.0
        dp[i] = 1 + total / 6
    return dp[n]

二十三、字符串高级算法

Z 函数

Z 函数用于计算每个后缀与原串的最长公共前缀。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def z_function(s):
    """Return the Z-array: z[i] is the length of the longest common prefix of s and s[i:].

    By convention z[0] = len(s). Runs in O(n) by reusing the rightmost match window
    [l, r). Returns [] for an empty string (this previously raised IndexError).
    """
    n = len(s)
    if n == 0:
        return []
    z = [0] * n
    z[0] = n
    l, r = 0, 0

    for i in range(1, n):
        if i < r:
            # Inside the window, s[i:r] equals s[i-l:r-l], so start from z[i - l], capped at r.
            z[i] = min(r - i, z[i - l])
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] > r:
            l, r = i, i + z[i]

    return z

Manacher 算法

Manacher 算法在线性时间内求出以每个位置为中心的最长回文半径,从而得到最长回文子串,也可用来统计回文子串的个数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def manacher(s):
t = '#' + '#'.join(s) + '#'
n = len(t)
p = [0] * n
center = right = 0

for i in range(n):
if i < right:
p[i] = min(right - i, p[2 * center - i])

while i - p[i] - 1 >= 0 and i + p[i] + 1 < n and t[i - p[i] - 1] == t[i + p[i] + 1]:
p[i] += 1

if i + p[i] > right:
center, right = i, i + p[i]

return p

def longest_palindrome_manacher(s):
    """Return the longest palindromic substring of ``s`` (the leftmost one on ties)."""
    radius = manacher(s)
    best = max(radius)
    center = radius.index(best)
    # Position (center - best) in the '#'-interleaved string maps to index (center - best) // 2 in s.
    begin = (center - best) // 2
    return s[begin:begin + best]

后缀数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def suffix_array(s):
    """Return the starting indices of the suffixes of ``s`` in lexicographic order.

    Prefix doubling: after the round with step k, ``rank`` orders suffixes by their
    first 2k characters. Python's sort makes this O(n log^2 n).
    """
    n = len(s)
    rank = [ord(ch) for ch in s] + [-1]
    sa = list(range(n))
    tmp = [0] * n
    k = 1

    def key_of(i):
        # (rank of the first k chars, rank of the next k chars or -1 past the end)
        return (rank[i], rank[i + k] if i + k < n else -1)

    while k < n:
        sa.sort(key=key_of)
        tmp[sa[0]] = 0
        for prev, curr in zip(sa, sa[1:]):
            tmp[curr] = tmp[prev] + (key_of(prev) < key_of(curr))
        rank = tmp[:]
        k *= 2

    return sa

def lcp_array(s, sa):
    """Kasai's algorithm: lcp[i] is the longest common prefix of suffixes sa[i] and sa[i + 1].

    Runs in O(n). When suffixes are visited in text order, the match length drops by
    at most one from one suffix to the next.
    """
    n = len(s)
    position = [0] * n
    for idx, start in enumerate(sa):
        position[start] = idx

    lcp = [0] * (n - 1)
    h = 0
    for i in range(n):
        pos = position[i]
        if pos == 0:
            continue
        j = sa[pos - 1]
        while i + h < n and j + h < n and s[i + h] == s[j + h]:
            h += 1
        lcp[pos - 1] = h
        if h > 0:
            h -= 1

    return lcp

AC 自动机

AC 自动机用于多模式串匹配。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from collections import deque

class AhoCorasick:
    """Aho–Corasick automaton for finding many patterns in a single pass over a text.

    Usage: call ``add_pattern`` for every pattern, then ``build`` once, then ``search``.
    Node 0 is the root. ``trie[v]`` maps a character to a child node, ``fail[v]`` is
    the failure link, and ``output[v]`` lists the indices of the patterns ending at v,
    including those inherited through failure links.
    """

    def __init__(self):
        self.trie = [{}]
        self.fail = [0]
        self.output = [[]]

    def add_pattern(self, pattern, index):
        """Insert ``pattern`` and record ``index`` at its terminal node."""
        node = 0
        for ch in pattern:
            child = self.trie[node].get(ch)
            if child is None:
                child = len(self.trie)
                self.trie.append({})
                self.fail.append(0)
                self.output.append([])
                self.trie[node][ch] = child
            node = child
        self.output[node].append(index)

    def build(self):
        """Compute failure links breadth-first and merge outputs along them."""
        queue = deque(self.trie[0].values())
        while queue:
            parent = queue.popleft()
            for ch, child in self.trie[parent].items():
                link = self.fail[parent]
                while link and ch not in self.trie[link]:
                    link = self.fail[link]
                self.fail[child] = self.trie[link].get(ch, 0)
                self.output[child].extend(self.output[self.fail[child]])
                queue.append(child)

    def search(self, text):
        """Return (end_index, pattern_index) for every occurrence, in text order."""
        matches = []
        state = 0
        for pos, ch in enumerate(text):
            while state and ch not in self.trie[state]:
                state = self.fail[state]
            state = self.trie[state].get(ch, 0)
            matches.extend((pos, idx) for idx in self.output[state])
        return matches

二十四、计算几何基础

点与向量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class Point:
    """2-D point / vector with value semantics.

    ``__eq__`` and ``__hash__`` compare coordinates, so equal points deduplicate in
    sets and dicts (for example ``set(points)`` in convex_hull). Do not mutate a point
    while it is stored in a set or used as a dict key.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __add__(self, other):
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        return Point(self.x - other.x, self.y - other.y)

    def __mul__(self, scalar):
        """Scale the vector by a number."""
        return Point(self.x * scalar, self.y * scalar)

    __rmul__ = __mul__  # also allow ``scalar * point``

    def __eq__(self, other):
        if not isinstance(other, Point):
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __hash__(self):
        return hash((self.x, self.y))

    def __repr__(self):
        return f"Point({self.x!r}, {self.y!r})"

    def dot(self, other):
        """Dot product."""
        return self.x * other.x + self.y * other.y

    def cross(self, other):
        """z-component of the 2-D cross product; > 0 when ``other`` is counter-clockwise from self."""
        return self.x * other.y - self.y * other.x

    def distance(self, other):
        """Euclidean distance to ``other``."""
        return ((self.x - other.x) ** 2 + (self.y - other.y) ** 2) ** 0.5

    def length(self):
        """Euclidean norm of the vector."""
        return (self.x ** 2 + self.y ** 2) ** 0.5

线段相交判断

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def direction(p, q, r):
    """Orientation of r relative to the directed line p→q: the cross product (q - p) × (r - p).

    Positive when r is to the left (counter-clockwise), negative when it is to the
    right, and 0 when the three points are collinear.
    """
    return (q.x - p.x) * (r.y - p.y) - (q.y - p.y) * (r.x - p.x)

def on_segment(p, q, r):
    """Return True if ``q`` lies inside the axis-aligned bounding box of segment p–r.

    This means "q is on the segment" only once p, q and r are known to be collinear.
    """
    within_x = min(p.x, r.x) <= q.x <= max(p.x, r.x)
    within_y = min(p.y, r.y) <= q.y <= max(p.y, r.y)
    return within_x and within_y

def segments_intersect(p1, p2, p3, p4):
    """Return True if the closed segments p1–p2 and p3–p4 share at least one point.

    A proper crossing shows up as opposite orientation signs with respect to both
    segments. Touching and collinear-overlap cases show up as an orientation of 0
    plus a bounding-box check.
    """
    d1 = direction(p3, p4, p1)
    d2 = direction(p3, p4, p2)
    d3 = direction(p1, p2, p3)
    d4 = direction(p1, p2, p4)

    def straddles(a, b):
        return (a > 0 and b < 0) or (a < 0 and b > 0)

    if straddles(d1, d2) and straddles(d3, d4):
        return True
    return (
        (d1 == 0 and on_segment(p3, p1, p4))
        or (d2 == 0 and on_segment(p3, p2, p4))
        or (d3 == 0 and on_segment(p1, p3, p2))
        or (d4 == 0 and on_segment(p1, p4, p2))
    )

凸包(Andrew 单调链算法)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def convex_hull(points):
    """Return the convex hull of ``points`` in counter-clockwise order (Andrew's monotone chain).

    Points need ``.x`` and ``.y`` attributes. Duplicate coordinates are removed and
    collinear boundary points are dropped. The hull starts at the point with the
    smallest x (then smallest y). Inputs with fewer than 2 distinct points are
    returned as they are.
    """
    # Deduplicate by coordinates. set() only removes duplicates when the point type
    # defines value equality, whereas keying on (x, y) works for any point type.
    unique = {}
    for p in points:
        unique.setdefault((p.x, p.y), p)
    points = sorted(unique.values(), key=lambda p: (p.x, p.y))

    if len(points) <= 1:
        return points

    def cross(o, a, b):
        return (a.x - o.x) * (b.y - o.y) - (a.y - o.y) * (b.x - o.x)

    lower = []
    for p in points:
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)

    upper = []
    for p in reversed(points):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)

    # The last point of each chain is the first point of the other one.
    return lower[:-1] + upper[:-1]

多边形面积

1
2
3
4
5
6
7
8
def polygon_area(points):
    """Area of a simple polygon whose vertices are given in order (shoelace formula)."""
    twice_area = 0
    count = len(points)
    for i, cur in enumerate(points):
        nxt = points[(i + 1) % count]
        twice_area += cur.x * nxt.y
        twice_area -= nxt.x * cur.y
    return abs(twice_area) / 2

二十五、博弈论基础

Nim 游戏

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def nim_game(piles):
    """Return True if the player to move wins normal-play Nim (the nim-sum is non-zero)."""
    nim_sum = 0
    for size in piles:
        nim_sum ^= size
    return bool(nim_sum)

def nim_move(piles):
    """Return (pile_index, stones_to_remove) for a winning Nim move, or (-1, 0) if none exists.

    Picks the first pile whose size drops when XOR-ed with the nim-sum. Reducing that
    pile to ``size ^ nim_sum`` leaves the opponent a nim-sum of zero.
    """
    nim_sum = 0
    for size in piles:
        nim_sum ^= size
    for idx, size in enumerate(piles):
        target = size ^ nim_sum
        if target < size:
            return idx, size - target
    return -1, 0

SG 函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def sg_function(n, moves):
    """Return Sprague–Grundy values sg[0..n] for a subtraction game with the given move sizes.

    sg[i] is the mex (smallest non-negative integer not present) of the values
    reachable in one move.
    """
    sg = [0] * (n + 1)
    for pile in range(1, n + 1):
        seen = {sg[pile - m] for m in moves if pile >= m}
        mex = 0
        while mex in seen:
            mex += 1
        sg[pile] = mex
    return sg

def combined_sg(sg_values):
    """Return True if the sum of independent games with these SG values is a first-player win.

    The SG value of the sum is the XOR of the components; it is non-zero exactly when
    the player to move wins.
    """
    total = 0
    for value in sg_values:
        total ^= value
    return bool(total)

巴什博弈

1
2
def bash_game(n, m):
    """Bash game: one pile of n stones, each move removes 1..m. True if the player to move wins.

    The losing positions are exactly the multiples of m + 1.
    """
    return bool(n % (m + 1))

威佐夫博弈

1
2
3
4
5
6
7
8
9
10
import math

def wythoff_game(a, b):
    """Return True if the player to move wins Wythoff's game with piles ``a`` and ``b``.

    The losing positions are (floor(k*phi), floor(k*phi) + k), where phi is the
    golden ratio and k = |a - b|. floor(k*phi) is computed exactly as
    (k + isqrt(5*k*k)) // 2. This avoids the float rounding of ``int(k * phi)``,
    which gives wrong answers for large k.
    """
    def isqrt(x):
        # Floor of the real square root via integer Newton iteration.
        if x < 2:
            return x
        r = 1 << ((x.bit_length() + 1) // 2)  # initial guess >= sqrt(x)
        while True:
            y = (r + x // r) // 2
            if y >= r:
                return r
            r = y

    if a > b:
        a, b = b, a
    k = b - a
    return a != (k + isqrt(5 * k * k)) // 2

二十六、常用技巧与黑科技

位运算技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def bit_tricks(x=10, n=4):
    """Print classic bit-manipulation identities for a sample integer.

    Args:
        x: non-negative integer the identities are applied to.
        n: exponent used in the ``1 << n`` examples. The original body used ``n``
            without defining it, which raised NameError.
    """
    print(x & (x - 1))         # clear the lowest set bit
    print(x & -x)              # isolate the lowest set bit (lowbit)
    print(x | (x + 1))         # set the lowest clear bit
    print(x & (x + 1))         # clear the trailing run of 1 bits
    print((x ^ (x - 1)) >> 1)  # mask of the trailing 0 bits, i.e. lowbit(x) - 1

    print(bin(x).count('1'))   # number of set bits

    print(1 << n)              # 2 ** n
    print((1 << n) - 1)        # n low bits all set to 1

def lowbit(x):
    """Return the value of the lowest set bit of ``x`` (two's complement: -x == ~x + 1)."""
    return x & (~x + 1)

def count_bits(n):
    """Return the number of set bits of a non-negative integer (Kernighan's trick)."""
    ones = 0
    while n != 0:
        n &= n - 1  # drops the lowest set bit
        ones += 1
    return ones

随机化技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import random

def shuffle_array(arr):
    """Shuffle ``arr`` in place with the Fisher–Yates algorithm and return the same list."""
    i = len(arr) - 1
    while i > 0:
        j = random.randint(0, i)  # pick from the not-yet-fixed prefix arr[0..i]
        arr[i], arr[j] = arr[j], arr[i]
        i -= 1
    return arr

def quick_select(arr, k):
    """Return the k-th smallest element (0-based) of ``arr``, or -1 if k is out of range.

    Reorders ``arr`` in place. The pivot is chosen uniformly at random, so the
    expected time is O(n). The previous fixed last-element pivot degraded to O(n^2)
    on already-sorted input. Note that -1 is ambiguous if -1 can be an element.
    """
    def partition(left, right):
        # Move a random pivot to the end, then run the Lomuto partition.
        r = random.randint(left, right)
        arr[r], arr[right] = arr[right], arr[r]
        pivot = arr[right]
        store = left
        for j in range(left, right):
            if arr[j] <= pivot:
                arr[store], arr[j] = arr[j], arr[store]
                store += 1
        arr[store], arr[right] = arr[right], arr[store]
        return store

    left, right = 0, len(arr) - 1
    while left <= right:
        pos = partition(left, right)
        if pos == k:
            return arr[pos]
        if pos < k:
            left = pos + 1
        else:
            right = pos - 1
    return -1

离散化

1
2
3
4
5
6
7
8
9
10
11
def discretize(arr):
    """Map each value to its 0-based rank among the distinct values (equal values share a rank)."""
    rank_of = {}
    for value in sorted(set(arr)):
        rank_of[value] = len(rank_of)
    return [rank_of[value] for value in arr]

def discretize_with_rank(arr):
    """Return 1-based ranks by value. Ties are broken by original position, so all ranks are distinct."""
    order = sorted(range(len(arr)), key=lambda idx: arr[idx])
    ranks = [0] * len(arr)
    for position, idx in enumerate(order, start=1):
        ranks[idx] = position
    return ranks

坐标压缩

1
2
3
4
5
6
7
8
def compress_coordinates(points):
    """Replace each (x, y) by (rank of x among the distinct xs, rank of y among the distinct ys), 0-based."""
    def ranks(values):
        return {v: i for i, v in enumerate(sorted(set(values)))}

    x_rank = ranks(p[0] for p in points)
    y_rank = ranks(p[1] for p in points)
    return [(x_rank[p[0]], y_rank[p[1]]) for p in points]

滑动窗口技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def sliding_window_techniques():
    """Placeholder for the sliding-window section; it does nothing and returns None.

    The concrete techniques are the standalone functions that follow it.
    """

def max_subarray_with_constraint(nums, k):
    """Return the maximum sum over all contiguous windows of exactly ``k`` elements.

    Returns 0 when no such window exists (k <= 0 or k > len(nums)). The previous
    version started the maximum at 0, so it returned 0 instead of the true (negative)
    answer when every window sum was negative.
    """
    if k <= 0 or k > len(nums):
        return 0
    window = sum(nums[:k])
    best = window
    for i in range(k, len(nums)):
        window += nums[i] - nums[i - k]  # slide: add the new element, drop the oldest
        best = max(best, window)
    return best

def longest_substring_with_k_distinct(s, k):
    """Length of the longest substring of ``s`` containing at most ``k`` distinct characters.

    Variable-size sliding window: extend on the right, and shrink from the left while
    the window holds more than k distinct characters.
    """
    from collections import defaultdict

    window = defaultdict(int)
    left = 0
    best = 0
    for right, ch in enumerate(s):
        window[ch] += 1
        while len(window) > k:
            out = s[left]
            window[out] -= 1
            if not window[out]:
                del window[out]
            left += 1
        best = max(best, right - left + 1)
    return best

二十七、比赛策略与心态

时间分配策略

阶段时间任务
第一遍15 分钟快速浏览所有题目,标记难度
填空题60 分钟完成所有结果填空题
简单编程90 分钟完成简单编程大题
难题攻坚60 分钟尝试中等和困难题目
检查15 分钟检查代码,边界测试

得分策略

  1. 填空题优先:填空题分值高,难度相对较低
  2. 部分分意识:编程题按测试点给分,能拿多少拿多少
  3. 暴力也是分:小数据范围可以用暴力解法
  4. 特殊性质:利用题目特殊性质简化问题

调试技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
def debug_print(*args, **kwargs):
    """print() to stderr with a "[DEBUG]" prefix, keeping debug output off stdout.

    Other print() keyword arguments (sep, end, flush, ...) are passed through. An
    explicit ``file=`` now overrides stderr; previously it raised TypeError
    ("got multiple values for keyword argument 'file'").
    """
    import sys
    kwargs.setdefault("file", sys.stderr)
    print("[DEBUG]", *args, **kwargs)

def assert_condition(condition, message=""):
    """Raise AssertionError("Assertion failed: <message>") when ``condition`` is falsy.

    Unlike the ``assert`` statement, this check is not stripped under ``python -O``.
    """
    if condition:
        return
    raise AssertionError(f"Assertion failed: {message}")

def test_with_cases(func, test_cases):
    """Run ``func`` over (input, expected) pairs and print one PASS/FAIL line per case.

    A tuple input is unpacked into positional arguments; any other input is passed as
    a single argument.
    """
    for number, (input_data, expected) in enumerate(test_cases, start=1):
        if isinstance(input_data, tuple):
            result = func(*input_data)
        else:
            result = func(input_data)
        status = "PASS" if result == expected else "FAIL"
        print(f"Test {number}: {status} (got {result}, expected {expected})")

常见失误避免

  1. 忘记初始化:循环内变量要重置
  2. 数组越界:检查索引范围
  3. 整数溢出:Python 整数不会溢出,但中间结果过大会明显拖慢运算,题目要求取模时应及时取模
  4. 精度问题:浮点数比较用容差
  5. 边界条件:空输入、单个元素、最大最小值

二十八、代码模板速查

输入模板

1
2
3
4
5
6
7
8
import sys
input = sys.stdin.readline

n = int(input())                                                    # one integer
a, b = [int(tok) for tok in input().split()]                        # two integers on one line
arr = [int(tok) for tok in input().split()]                         # a line of integers
matrix = [[int(tok) for tok in input().split()] for _ in range(n)]  # n rows of integers
string = input().strip()  # strip the trailing '\n' that readline keeps

快速排序

1
2
3
4
5
6
7
8
def quick_sort(arr):
    """Return a sorted list using three-way partitioning around the middle element.

    Lists with fewer than two elements are returned as-is (the same object).
    """
    if len(arr) < 2:
        return arr
    pivot = arr[len(arr) // 2]
    smaller, equal, larger = [], [], []
    for x in arr:
        if x < pivot:
            smaller.append(x)
        elif x == pivot:
            equal.append(x)
        elif x > pivot:
            larger.append(x)
    return quick_sort(smaller) + equal + quick_sort(larger)

并查集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class UnionFind:
    """Disjoint-set union over elements 0..n-1 with path compression and union by rank."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n  # upper bound on each root's tree height

    def find(self, x):
        """Return the representative of x's set, re-pointing every node on the path at it."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while x != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets containing x and y; return False if they were already the same set."""
        px, py = self.find(x), self.find(y)
        if px == py:
            return False
        if self.rank[px] < self.rank[py]:
            px, py = py, px  # attach the shallower tree under the deeper one
        self.parent[py] = px
        if self.rank[px] == self.rank[py]:
            self.rank[px] += 1
        return True

线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class SegmentTree:
    """Range-sum segment tree over a fixed-length list, with point updates.

    Indices are 0-based and query bounds are inclusive. ``tree[1]`` is the root, and
    node v has children 2v and 2v + 1.
    """

    def __init__(self, data):
        self.n = len(data)
        self.tree = [0] * (4 * self.n)
        if self.n:  # building over an empty range would recurse forever
            self._build(data, 1, 0, self.n - 1)

    def _build(self, data, node, l, r):
        if l == r:
            self.tree[node] = data[l]
        else:
            mid = (l + r) // 2
            self._build(data, node * 2, l, mid)
            self._build(data, node * 2 + 1, mid + 1, r)
            self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

    def query(self, ql, qr):
        """Return data[ql] + ... + data[qr]; 0 for an empty tree or a disjoint range."""
        if self.n == 0:
            return 0
        return self._query(1, 0, self.n - 1, ql, qr)

    def _query(self, node, l, r, ql, qr):
        if ql > r or qr < l:
            return 0
        if ql <= l and r <= qr:
            return self.tree[node]
        mid = (l + r) // 2
        return self._query(node * 2, l, mid, ql, qr) + self._query(node * 2 + 1, mid + 1, r, ql, qr)

    def update(self, idx, val):
        """Set data[idx] = val and refresh the sums on the root-to-leaf path.

        Raises:
            IndexError: if idx is outside 0..n-1.
        """
        if not 0 <= idx < self.n:
            raise IndexError("segment tree index out of range")
        self._update(1, 0, self.n - 1, idx, val)

    def _update(self, node, l, r, idx, val):
        if l == r:
            self.tree[node] = val
            return
        mid = (l + r) // 2
        if idx <= mid:
            self._update(node * 2, l, mid, idx, val)
        else:
            self._update(node * 2 + 1, mid + 1, r, idx, val)
        self.tree[node] = self.tree[node * 2] + self.tree[node * 2 + 1]

Dijkstra

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import heapq

def dijkstra(graph, start, n):
    """Single-source shortest paths for non-negative edge weights.

    ``graph[u]`` lists (v, w) edges. Returns the distance list, with inf for
    unreachable nodes. Stale heap entries are skipped when popped (lazy deletion).
    """
    dist = [float('inf')] * n
    dist[start] = 0
    heap = [(0, start)]
    while heap:
        d, u = heapq.heappop(heap)
        if d > dist[u]:
            continue  # an outdated entry for u
        for v, w in graph[u]:
            candidate = d + w
            if candidate < dist[v]:
                dist[v] = candidate
                heapq.heappush(heap, (candidate, v))
    return dist

BFS 模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from collections import deque

def bfs(graph, start, n):
    """Breadth-first search from ``start``; return a list flagging every reachable node."""
    seen = [False] * n
    seen[start] = True
    frontier = deque([start])
    while frontier:
        node = frontier.popleft()
        for nxt in graph[node]:
            if seen[nxt]:
                continue
            seen[nxt] = True
            frontier.append(nxt)
    return seen

DFS 模板

1
2
3
4
5
def dfs(graph, u, visited):
    """Mark every vertex reachable from ``u`` in ``visited`` (recursive depth-first search)."""
    visited[u] = True
    for v in graph[u]:
        if visited[v]:
            continue
        dfs(graph, v, visited)

二分查找

1
2
3
4
5
6
7
8
9
10
11
def binary_search(arr, target):
    """Return an index of ``target`` in the ascending list ``arr``, or -1 if it is absent."""
    lo, hi = 0, len(arr) - 1
    while lo <= hi:  # closed interval [lo, hi]
        mid = (lo + hi) // 2
        value = arr[mid]
        if value == target:
            return mid
        if value < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return -1

前缀和

1
2
3
4
5
6
7
8
9
def prefix_sum(arr):
    """Return P with P[0] = 0 and P[i] = arr[0] + ... + arr[i-1] (length len(arr) + 1)."""
    prefix = [0]
    running = 0
    for value in arr:
        running = running + value
        prefix.append(running)
    return prefix

def range_sum(prefix, l, r):
    """Return arr[l] + ... + arr[r] (inclusive), given the prefix array of arr."""
    upper = prefix[r + 1]
    lower = prefix[l]
    return upper - lower

二十九、网络流与匹配基础

最大流问题

Ford-Fulkerson 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from collections import deque

def max_flow(capacity, source, sink, n):
    """Maximum flow from ``source`` to ``sink`` (Edmonds–Karp: shortest augmenting paths by BFS).

    ``capacity`` is an n×n matrix. It is modified in place and ends up holding the
    residual capacities.
    """
    # Residual-graph neighbours: any pair with capacity in either direction. Each
    # adjacency list is built once, so no neighbour appears twice (the previous
    # construction appended every edge twice in each direction).
    graph = [
        [v for v in range(n) if capacity[u][v] > 0 or capacity[v][u] > 0]
        for u in range(n)
    ]

    def bfs():
        # Returns (parent pointers, bottleneck of the path found, or 0 if none).
        parent = [-1] * n
        parent[source] = source
        queue = deque([(source, float('inf'))])

        while queue:
            u, flow = queue.popleft()
            for v in graph[u]:
                if parent[v] == -1 and capacity[u][v] > 0:
                    parent[v] = u
                    new_flow = min(flow, capacity[u][v])
                    if v == sink:
                        return parent, new_flow
                    queue.append((v, new_flow))

        return parent, 0

    total_flow = 0

    while True:
        parent, flow = bfs()
        if flow == 0:
            break

        total_flow += flow
        v = sink
        while v != source:
            u = parent[v]
            capacity[u][v] -= flow
            capacity[v][u] += flow
            v = u

    return total_flow

Dinic 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from collections import deque

class Dinic:
    """Maximum flow with Dinic's algorithm (BFS level graph + blocking-flow DFS).

    Each edge is a mutable list ``[to, residual_capacity, reverse_index]``, where
    ``reverse_index`` locates the paired reverse edge inside ``graph[to]``.
    """

    def __init__(self, n):
        self.n = n
        self.graph = [[] for _ in range(n)]

    def add_edge(self, u, v, cap):
        """Add directed edge u->v with capacity ``cap`` plus its zero-capacity reverse edge."""
        forward = [v, cap, len(self.graph[v])]
        self.graph[u].append(forward)
        backward = [u, 0, len(self.graph[u]) - 1]
        self.graph[v].append(backward)

    def bfs(self, s, t):
        """Assign BFS levels over edges with residual capacity; return True if t is reachable."""
        level = [-1] * self.n
        self.level = level
        frontier = deque([s])
        level[s] = 0

        while frontier:
            u = frontier.popleft()
            for v, cap, _ in self.graph[u]:
                if cap > 0 and level[v] == -1:
                    level[v] = level[u] + 1
                    frontier.append(v)

        return level[t] != -1

    def dfs(self, u, t, f):
        """Push at most ``f`` units along one level-increasing path from u to t; return the amount."""
        if u == t:
            return f

        edges = self.graph[u]
        for i in range(self.it[u], len(edges)):
            self.it[u] = i  # current-arc pointer: skip edges already exhausted in this phase
            edge = edges[i]
            v, cap, rev = edge
            if cap <= 0 or self.level[v] != self.level[u] + 1:
                continue
            pushed = self.dfs(v, t, min(f, cap))
            if pushed > 0:
                edge[1] -= pushed
                self.graph[v][rev][1] += pushed
                return pushed

        return 0

    def max_flow(self, s, t):
        """Return the maximum flow from s to t (the residual capacities are updated in place)."""
        total = 0
        while self.bfs(s, t):
            self.it = [0] * self.n
            pushed = self.dfs(s, t, float('inf'))
            while pushed:
                total += pushed
                pushed = self.dfs(s, t, float('inf'))
        return total

二分图匹配

匈牙利算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def hungarian(graph, n_left, n_right):
    """Maximum bipartite matching size via augmenting paths (Kuhn's algorithm), O(V·E).

    ``graph[u]`` lists the right-side vertices adjacent to left vertex u.
    """
    owner = [-1] * n_right  # owner[v] = left vertex currently matched to right vertex v

    def augment(u, seen):
        for v in graph[u]:
            if seen[v]:
                continue
            seen[v] = True
            if owner[v] == -1 or augment(owner[v], seen):
                owner[v] = u
                return True
        return False

    return sum(1 for u in range(n_left) if augment(u, [False] * n_right))

Hopcroft-Karp 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from collections import deque

def hopcroft_karp(graph, n_left, n_right):
    """Maximum bipartite matching size via Hopcroft–Karp.

    ``graph[u]`` lists the right-side vertices adjacent to left vertex u. Each phase
    runs a BFS that layers left vertices by distance from the free ones, then a DFS
    that augments only along layer-increasing paths.
    """
    INF = float('inf')
    match_left = [-1] * n_left
    match_right = [-1] * n_right
    dist = [0] * n_left

    def bfs():
        queue = deque()
        for u in range(n_left):
            if match_left[u] == -1:
                dist[u] = 0
                queue.append(u)
            else:
                dist[u] = INF

        shortest = INF  # layer at which a free right vertex is first reached
        while queue:
            u = queue.popleft()
            if dist[u] >= shortest:
                continue
            for v in graph[u]:
                partner = match_right[v]
                if partner == -1:
                    shortest = dist[u] + 1
                elif dist[partner] == INF:
                    dist[partner] = dist[u] + 1
                    queue.append(partner)

        return shortest != INF

    def dfs(u):
        for v in graph[u]:
            partner = match_right[v]
            if partner == -1 or (dist[partner] == dist[u] + 1 and dfs(partner)):
                match_left[u] = v
                match_right[v] = u
                return True
        dist[u] = INF  # dead end: exclude u from the rest of this phase
        return False

    matched = 0
    while bfs():
        for u in range(n_left):
            if match_left[u] == -1 and dfs(u):
                matched += 1

    return matched

三十、高级数据结构进阶

平衡二叉搜索树(Treap)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import random

class TreapNode:
    """Treap node: BST-ordered by ``key``, heap-ordered by a random ``priority``."""

    def __init__(self, key):
        self.key = key
        # A random priority keeps the tree balanced in expectation.
        self.priority = random.randint(1, 10**9)
        self.left = None
        self.right = None
        self.size = 1  # number of nodes in this subtree

def get_size(node):
    """Return the subtree size stored in ``node``, treating an empty subtree (None) as 0."""
    if node is None:
        return 0
    return node.size

def update_size(node):
    """Recompute ``node.size`` from its children; does nothing for None."""
    if node is None:
        return
    node.size = get_size(node.left) + get_size(node.right) + 1

def rotate_right(p):
    """Rotate the subtree rooted at ``p`` right; p's left child becomes the new root.

    Sizes are refreshed bottom-up (the old root first) and the new root is returned.
    """
    new_root = p.left
    p.left = new_root.right
    new_root.right = p
    for node in (p, new_root):
        update_size(node)
    return new_root

def rotate_left(p):
    """Rotate the subtree rooted at ``p`` left; p's right child becomes the new root.

    Sizes are refreshed bottom-up (the old root first) and the new root is returned.
    """
    new_root = p.right
    p.right = new_root.left
    new_root.left = p
    for node in (p, new_root):
        update_size(node)
    return new_root

def insert(node, key):
    """Insert ``key`` into the treap rooted at ``node`` and return the new root.

    Equal keys go to the left subtree. A child whose priority is smaller than its
    parent's is rotated up, so smaller priorities stay closer to the root.
    """
    if node is None:
        return TreapNode(key)
    if key <= node.key:
        node.left = insert(node.left, key)
        if node.left.priority < node.priority:
            node = rotate_right(node)
    else:
        node.right = insert(node.right, key)
        if node.right.priority < node.priority:
            node = rotate_left(node)
    update_size(node)
    return node

def delete(node, key):
    """Delete one node holding ``key`` from the treap rooted at ``node``; return the new root.

    A matching node with two children is rotated down, toward the child with the
    smaller priority, until it has at most one child and can be spliced out.
    """
    if node is None:
        return None
    if key < node.key:
        node.left = delete(node.left, key)
    elif key > node.key:
        node.right = delete(node.right, key)
    elif node.left is None:
        return node.right
    elif node.right is None:
        return node.left
    elif node.left.priority < node.right.priority:
        node = rotate_right(node)
        node.right = delete(node.right, key)
    else:
        node = rotate_left(node)
        node.left = delete(node.left, key)
    update_size(node)
    return node

def kth_element(node, k):
    """Return the k-th smallest key (1-based) in the treap rooted at ``node``, or None."""
    while node is not None:
        left_size = get_size(node.left)
        if k == left_size + 1:
            return node.key
        if k <= left_size:
            node = node.left
        else:
            k -= left_size + 1
            node = node.right
    return None

跳表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import random

class SkipListNode:
    """Skip-list node holding ``key`` and one forward pointer for each level 0..level."""

    def __init__(self, key=None, level=0):
        self.key = key
        self.forward = [None for _ in range(level + 1)]

class SkipList:
    """Skip list keeping keys in ascending order (duplicates allowed).

    ``header`` is a sentinel with ``max_level + 1`` forward pointers, and ``level`` is
    the highest level currently in use. Each new node is promoted one more level
    with probability ``p`` (up to ``max_level``).
    """

    def __init__(self, max_level=16, p=0.5):
        self.max_level = max_level
        self.p = p
        self.level = 0
        self.header = SkipListNode(level=max_level)

    def random_level(self):
        """Draw a node level from a geometric distribution capped at ``max_level``."""
        level = 0
        while random.random() < self.p and level < self.max_level:
            level += 1
        return level

    def insert(self, key):
        """Insert ``key``; at level 0 it is placed before any existing equal keys."""
        update = [None] * (self.max_level + 1)
        current = self.header

        # update[i] = last node at level i whose key is < key
        for i in range(self.level, -1, -1):
            while current.forward[i] and current.forward[i].key < key:
                current = current.forward[i]
            update[i] = current

        new_level = self.random_level()
        if new_level > self.level:
            for i in range(self.level + 1, new_level + 1):
                update[i] = self.header
            self.level = new_level

        new_node = SkipListNode(key, new_level)
        for i in range(new_level + 1):
            new_node.forward[i] = update[i].forward[i]
            update[i].forward[i] = new_node

    def search(self, key):
        """Return True if ``key`` is present.

        Always returns a bool; previously it returned None when the search ran past
        the last node.
        """
        current = self.header
        for i in range(self.level, -1, -1):
            while current.forward[i] and current.forward[i].key < key:
                current = current.forward[i]
        current = current.forward[0]
        return current is not None and current.key == key

    def delete(self, key):
        """Remove one occurrence of ``key`` if present, shrinking ``level`` if top levels empty out."""
        update = [None] * (self.max_level + 1)
        current = self.header

        for i in range(self.level, -1, -1):
            while current.forward[i] and current.forward[i].key < key:
                current = current.forward[i]
            update[i] = current

        current = current.forward[0]
        if current and current.key == key:
            for i in range(self.level + 1):
                if update[i].forward[i] != current:
                    break  # the node does not reach this level
                update[i].forward[i] = current.forward[i]

            while self.level > 0 and not self.header.forward[self.level]:
                self.level -= 1

可持久化数据结构

可持久化线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class PersistentSegmentTree:
    """Persistent range-sum segment tree: every point update creates a new version.

    ``roots[v]`` is the root node id of version v; version 0 is the initial array.
    Nodes live in ``nodes`` as dicts {'val', 'left', 'right'}. An update copies only
    the nodes on one root-to-leaf path and shares everything else with older versions.
    """

    def __init__(self, arr):
        self.n = len(arr)
        self.roots = []
        self.nodes = []
        # An empty array has no nodes (building it would recurse forever), so its
        # version 0 root is None.
        self.roots.append(self._build(arr, 0, self.n - 1) if self.n else None)

    def _new_node(self, val=0):
        self.nodes.append({'val': val, 'left': None, 'right': None})
        return len(self.nodes) - 1

    def _build(self, arr, l, r):
        node = self._new_node()
        if l == r:
            self.nodes[node]['val'] = arr[l]
        else:
            mid = (l + r) // 2
            self.nodes[node]['left'] = self._build(arr, l, mid)
            self.nodes[node]['right'] = self._build(arr, mid + 1, r)
            self.nodes[node]['val'] = self.nodes[self.nodes[node]['left']]['val'] + self.nodes[self.nodes[node]['right']]['val']
        return node

    def _update(self, prev, l, r, idx, val):
        # Copy the path towards idx and reuse prev's untouched child.
        node = self._new_node()
        if l == r:
            self.nodes[node]['val'] = val
        else:
            mid = (l + r) // 2
            if idx <= mid:
                self.nodes[node]['left'] = self._update(self.nodes[prev]['left'], l, mid, idx, val)
                self.nodes[node]['right'] = self.nodes[prev]['right']
            else:
                self.nodes[node]['left'] = self.nodes[prev]['left']
                self.nodes[node]['right'] = self._update(self.nodes[prev]['right'], mid + 1, r, idx, val)
            self.nodes[node]['val'] = self.nodes[self.nodes[node]['left']]['val'] + self.nodes[self.nodes[node]['right']]['val']
        return node

    def update(self, idx, val, version=None):
        """Create a new version equal to ``version`` (default: the latest) but with arr[idx] = val.

        Returns the new version number.

        Raises:
            IndexError: if idx is outside 0..n-1 (previously such an index silently
                overwrote the wrong leaf).
        """
        if not 0 <= idx < self.n:
            raise IndexError("index out of range")
        base = self.roots[-1] if version is None else self.roots[version]
        new_root = self._update(base, 0, self.n - 1, idx, val)
        self.roots.append(new_root)
        return len(self.roots) - 1

    def _query(self, node, l, r, ql, qr):
        if ql > r or qr < l:
            return 0
        if ql <= l and r <= qr:
            return self.nodes[node]['val']
        mid = (l + r) // 2
        return self._query(self.nodes[node]['left'], l, mid, ql, qr) + self._query(self.nodes[node]['right'], mid + 1, r, ql, qr)

    def query(self, version, ql, qr):
        """Return arr[ql] + ... + arr[qr] (inclusive) as of ``version``; 0 for an empty array."""
        if self.n == 0:
            return 0
        return self._query(self.roots[version], 0, self.n - 1, ql, qr)

三十一、更多真题深度解析

2018 年真题:第几天

题目:2000 年的 1 月 1 日,是那一年的第 1 天。那么,2000 年的 5 月 4 日,是那一年的第几天?

解题思路

1
2
3
4
5
from datetime import datetime

start = datetime(2000, 1, 1)
end = datetime(2000, 5, 4)
# January 1st is day 1, so add one to the number of days elapsed since it.
day_of_year = (end - start).days + 1
print(day_of_year)

2018 年真题:明码

题目:汉字的字形存在于字库中,即便在今天,16 点阵的字库也仍然使用广泛。16 点阵的字库把每个汉字看成是 16x16 个像素信息。把这些信息记录下来就是字节。一个字节可以存储 8 位信息,用 32 个字节就可以存一个汉字的字形了。把每个字节转为 2 进制表示,1 表示墨迹,0 表示底色。每行 2 个字节,一共 16 行,布局是:第 1 字节,第 2 字节,第 3 字节,第 4 字节…

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def bytes_to_binary(byte_val):
    """Return the 8-character binary string of one byte.

    Accepts unsigned (0..255) or signed (-128..127) byte values; only the low 8 bits
    are used. Previously a negative value produced a string like '-0b100' from bin().
    """
    return format(byte_val & 0xFF, '08b')

def decode_character(data):
    """Render dot-matrix font bytes as text, one output line per 2-byte (16-pixel) row.

    In each row, a 1 bit becomes '*' and a 0 bit becomes ' '.

    Raises:
        ValueError: if ``data`` has an odd number of bytes, so the last row would be
            incomplete (previously this surfaced as an IndexError).
    """
    if len(data) % 2:
        raise ValueError(f"expected an even number of bytes, got {len(data)}")
    rows = []
    for i in range(0, len(data), 2):
        bits = bytes_to_binary(data[i]) + bytes_to_binary(data[i + 1])
        rows.append(''.join('*' if b == '1' else ' ' for b in bits))
    return '\n'.join(rows)

# Glyph bytes passed to decode_character, which reads them two bytes per row.
# NOTE(review): this list holds 63 values (16 + 15 + 16 + 16 per line). A 16x16
# glyph needs exactly 32 bytes and decode_character pairs bytes, so an odd count
# cannot be decoded as-is. Verify against the original problem data.
data = [
4, 0, 4, 0, 4, 0, 4, 32, 4, 0, 4, 0, 4, 0, 68, 0,
4, 3, 252, 0, 36, 4, 3, 252, 32, 36, 32, 36, 32, 36, 32,
36, 32, 36, 32, 36, 32, 36, 32, 36, 32, 36, 32, 164, 32, 36, 32,
36, 32, 36, 32, 36, 32, 36, 32, 36, 32, 36, 32, 4, 0, 4, 0
]

print(decode_character(data))

2017 年真题:迷宫

题目:X 星球的居民并不友好,威威刚踏上 X 星球就被抓进了迷宫。迷宫是一个 10x10 的方格,每个格子里有一个数字。威威需要从左上角走到右下角,每次只能向右或向下走,求经过的数字之和最小的路径。

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def min_path_sum(grid):
    """Minimum sum along a path from the top-left to the bottom-right cell, moving only right or down."""
    rows, cols = len(grid), len(grid[0])
    best = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            if i == 0 and j == 0:
                best[i][j] = grid[0][0]
            elif i == 0:
                best[i][j] = best[i][j - 1] + grid[i][j]  # first row: only from the left
            elif j == 0:
                best[i][j] = best[i - 1][j] + grid[i][j]  # first column: only from above
            else:
                best[i][j] = min(best[i - 1][j], best[i][j - 1]) + grid[i][j]
    return best[-1][-1]


grid = [
    [1, 4, 3, 8, 2, 5, 9, 6, 7, 1],
    [2, 6, 5, 4, 1, 3, 8, 9, 7, 2],
    [3, 1, 2, 7, 9, 5, 4, 6, 8, 3],
    [4, 9, 8, 5, 6, 2, 1, 3, 7, 4],
    [5, 2, 7, 1, 8, 4, 6, 9, 3, 5],
    [6, 3, 4, 9, 7, 1, 5, 2, 8, 6],
    [7, 8, 1, 6, 3, 9, 2, 4, 5, 7],
    [8, 5, 9, 2, 4, 7, 3, 1, 6, 8],
    [9, 7, 6, 3, 5, 8, 1, 2, 4, 9],
    [1, 2, 3, 4, 5, 6, 7, 8, 9, 1]
]

print(min_path_sum(grid))

2016 年真题:煤球数目

题目:有一堆煤球,堆成三角棱锥形。具体:第一层放 1 个,第二层 3 个(排列成三角形),第三层 6 个(排列成三角形),第四层 10 个(排列成三角形),… 如果一共有 100 层,共有多少个煤球?

解题思路

1
2
3
4
5
6
7
8
9
def coal_balls(n):
    """Return the number of coal balls in an n-layer triangular pile.

    Layer ``k`` holds the k-th triangular number k*(k+1)/2 balls.
    """
    return sum(layer * (layer + 1) // 2 for layer in range(1, n + 1))

print(coal_balls(100))

2015 年真题:奖券数目

题目:有些人很迷信数字,比如带“4”的数字,认为和“死”谐音,就觉得不吉利。虽然这些说法纯属无稽之谈,但有时还要迎合大众的需求。某抽奖活动的奖券号码是 5 位数(10000-99999),要求其中不出现“4”,请求出一共有多少种可能的号码。

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
def count_tickets():
    """Brute force: count numbers 10000..99999 that contain no digit '4'."""
    return sum(1 for num in range(10000, 100000) if '4' not in str(num))

print(count_tickets())

def count_tickets_math():
    """Same count by the product rule: the leading digit has 8 choices
    (1-9 except 4) and each of the other four digits has 9 (0-9 except 4)."""
    return 8 * 9 ** 4

print(count_tickets_math())

三十二、数学专题进阶

莫比乌斯反演

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def mobius_sieve(n):
    """Return ``mu`` of length n + 1 where mu[k] is the Möbius function of k.

    Uses a linear (Euler) sieve: every composite is crossed out once, by its
    smallest prime factor. mu[0] is left at the placeholder 1.
    """
    mu = [1] * (n + 1)
    composite = [False] * (n + 1)
    primes = []

    for i in range(2, n + 1):
        if not composite[i]:
            primes.append(i)
            mu[i] = -1
        for p in primes:
            multiple = i * p
            if multiple > n:
                break
            composite[multiple] = True
            if i % p == 0:
                # p already divides i, so p*p divides i*p.
                mu[multiple] = 0
                break
            mu[multiple] = -mu[i]

    return mu

def sum_gcd(n):
    """Return the sum of gcd(i, j) over 1 <= i, j <= n using Möbius inversion:

        sum_d d * sum_k mu(k) * floor(n / (d * k)) ** 2
    """
    mu = mobius_sieve(n)
    total = 0
    for d in range(1, n + 1):
        inner = sum(mu[k] * (n // (d * k)) ** 2 for k in range(1, n // d + 1))
        total += d * inner
    return total

卢卡斯定理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def lucas(n, m, p):
    """Return C(n, m) mod p for a prime p using Lucas' theorem:

        C(n, m) ≡ C(n // p, m // p) * C(n % p, m % p)  (mod p)
    """
    if m == 0:
        return 1
    return (lucas(n // p, m // p, p) * comb(n % p, m % p, p)) % p

def comb(n, m, p):
    """Return C(n, m) mod p for a prime p, assuming 0 <= n < p.

    Computes n*(n-1)*...*(n-m+1) / m! with a single Fermat inverse. This takes
    O(min(m, n-m)) time and O(1) memory however large p is; building a
    factorial table of size p + 1 would be infeasible for e.g. p = 10**9 + 7.

    Returns 0 when m < 0 or m > n.
    """
    if m < 0 or m > n:
        return 0
    m = min(m, n - m)  # C(n, m) == C(n, n - m); use the shorter product
    numerator = 1
    denominator = 1
    for i in range(m):
        numerator = numerator * (n - i) % p
        denominator = denominator * (i + 1) % p
    # m < p, so m! is invertible modulo the prime p.
    return numerator * pow(denominator, p - 2, p) % p

高斯消元

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def gaussian_elimination(matrix, n):
    """Solve an n x n linear system given as an augmented matrix.

    ``matrix`` holds n rows of n + 1 numbers (coefficients, then the
    right-hand side) and is modified in place. Partial pivoting is used.
    A pivot whose magnitude is below 1e-10 counts as zero: that column is
    skipped during elimination, and that unknown is not divided by it.

    Returns:
        A list of the n solution values.
    """
    eps = 1e-10

    # Forward elimination into upper-triangular form.
    for col in range(n):
        pivot_row = col
        for r in range(col + 1, n):
            if abs(matrix[r][col]) > abs(matrix[pivot_row][col]):
                pivot_row = r
        matrix[col], matrix[pivot_row] = matrix[pivot_row], matrix[col]

        pivot = matrix[col][col]
        if abs(pivot) < eps:
            continue

        base = matrix[col]
        for r in range(col + 1, n):
            row = matrix[r]
            factor = row[col] / pivot
            for c in range(col, n + 1):
                row[c] -= factor * base[c]

    # Back substitution, last unknown first.
    solution = [0] * n
    for i in reversed(range(n)):
        value = matrix[i][n]
        for j in range(i + 1, n):
            value -= matrix[i][j] * solution[j]
        if abs(matrix[i][i]) > eps:
            value /= matrix[i][i]
        solution[i] = value

    return solution

矩阵行列式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def determinant(matrix, n, mod=None):
    """Return the determinant of the n x n ``matrix``. The input is not modified.

    Without ``mod``, float Gaussian elimination is used. With ``mod`` (which
    must be prime, for the Fermat inverses), the work is done modulo ``mod``
    and the result lies in [0, mod).

    A row swap happens only when the current diagonal entry is zero, and each
    swap flips the sign of the result.
    """
    if mod:
        # Reduce first so that the zero-pivot test below is done modulo mod:
        # an entry such as 7 with mod=7 must count as zero.
        matrix = [[x % mod for x in row] for row in matrix]
    else:
        matrix = [row[:] for row in matrix]
    det = 1

    for i in range(n):
        if matrix[i][i] == 0:
            for j in range(i + 1, n):
                if matrix[j][i] != 0:
                    matrix[i], matrix[j] = matrix[j], matrix[i]
                    det = -det
                    break

        if matrix[i][i] == 0:
            return 0

        if mod:
            inv_pivot = pow(matrix[i][i], mod - 2, mod)  # computed once per pivot
        for j in range(i + 1, n):
            if matrix[j][i] == 0:
                continue
            if mod:
                ratio = matrix[j][i] * inv_pivot % mod
            else:
                ratio = matrix[j][i] / matrix[i][i]
            for k in range(i, n):
                matrix[j][k] -= ratio * matrix[i][k]
                if mod:
                    matrix[j][k] %= mod

    for i in range(n):
        det *= matrix[i][i]
        if mod:
            det %= mod

    return det

FFT 快速傅里叶变换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import cmath

def fft(a, invert):
    """In-place iterative radix-2 FFT of the list ``a``; len(a) must be a power of two.

    ``invert=False`` uses twiddle angle +2*pi/len. ``invert=True`` uses
    -2*pi/len and then divides by n, so the two calls are exact inverses of
    each other. (This sign convention is the opposite of the usual DFT
    definition, which does not matter for convolution.)
    """
    n = len(a)
    # Bit-reversal permutation.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    # Butterfly passes over blocks of length 2, 4, ..., n.
    length = 2
    while length <= n:
        ang = 2 * cmath.pi / length * (-1 if invert else 1)
        wlen = cmath.exp(1j * ang)
        for i in range(0, n, length):
            w = 1
            for j in range(i, i + length // 2):
                u = a[j]
                v = a[j + length // 2] * w
                a[j] = u + v
                a[j + length // 2] = u - v
                w *= wlen
        length *= 2

    if invert:
        for i in range(n):
            a[i] /= n

def multiply_polynomials(a, b):
    """Return the coefficients of the product of two integer polynomials.

    Coefficients run from the constant term upward. Each result is rounded to
    the nearest integer, so the inputs are assumed to have integer
    coefficients. If either polynomial is empty, the product is [].
    Any sequence type (list, tuple, ...) is accepted.
    """
    if not a or not b:
        return []
    result_len = len(a) + len(b) - 1

    n = 1
    while n < len(a) + len(b):
        n *= 2

    fa = list(a) + [0] * (n - len(a))
    fb = list(b) + [0] * (n - len(b))

    fft(fa, False)
    fft(fb, False)

    for i in range(n):
        fa[i] *= fb[i]

    fft(fa, True)

    return [int(round(x.real)) for x in fa[:result_len]]

三十三、搜索算法进阶

A* 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import heapq

def a_star(grid, start, end):
    """Find a shortest 4-directional path on a grid with A* search.

    Uses the Manhattan-distance heuristic and unit step cost.

    Args:
        grid: list of rows; ``grid[r][c] == 0`` marks a passable cell.
        start: (row, col) tuple where the search begins.
        end: (row, col) tuple to reach.

    Returns:
        The cells from ``start`` to ``end`` inclusive, or None if ``end`` is
        unreachable.
    """
    def heuristic(a, b):
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

    rows, cols = len(grid), len(grid[0])
    open_set = [(0, start)]
    came_from = {}
    g_score = {start: 0}
    f_score = {start: heuristic(start, end)}

    while open_set:
        f, current = heapq.heappop(open_set)
        # Lazy deletion: a node is pushed again whenever a cheaper route to it
        # is found, so skip older heap entries whose f no longer matches.
        if f > f_score[current]:
            continue

        if current == end:
            path = []
            while current in came_from:
                path.append(current)
                current = came_from[current]
            path.append(start)
            return path[::-1]

        for dx, dy in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
            neighbor = (current[0] + dx, current[1] + dy)
            if 0 <= neighbor[0] < rows and 0 <= neighbor[1] < cols and grid[neighbor[0]][neighbor[1]] == 0:
                tentative_g = g_score[current] + 1

                if neighbor not in g_score or tentative_g < g_score[neighbor]:
                    came_from[neighbor] = current
                    g_score[neighbor] = tentative_g
                    f_score[neighbor] = tentative_g + heuristic(neighbor, end)
                    heapq.heappush(open_set, (f_score[neighbor], neighbor))

    return None

IDA* 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def ida_star(start, goal, heuristic, get_neighbors):
    """Return the cost of a cheapest path from ``start`` to ``goal`` (IDA*).

    Args:
        start: node where the search begins.
        goal: target node (compared with ``==``).
        heuristic: ``heuristic(node, goal)`` returns a non-overestimating
            estimate of the remaining cost.
        get_neighbors: ``get_neighbors(node)`` yields ``(neighbor, cost)`` pairs.

    Returns:
        The path cost (0 when ``start == goal``), or None when the search
        space is exhausted without reaching ``goal``. If ``goal`` is
        unreachable and paths are unbounded (e.g. the graph has cycles), the
        search does not terminate.
    """
    def search(node, g, bound):
        # Returns (True, cost) if goal is reached within bound, otherwise
        # (False, smallest f that exceeded bound). Using a flag instead of a
        # negated cost lets a zero-cost solution be reported correctly.
        f = g + heuristic(node, goal)
        if f > bound:
            return False, f
        if node == goal:
            return True, g

        min_bound = float('inf')
        for neighbor, cost in get_neighbors(node):
            found, value = search(neighbor, g + cost, bound)
            if found:
                return True, value
            min_bound = min(min_bound, value)

        return False, min_bound

    bound = heuristic(start, goal)
    while True:
        found, value = search(start, 0, bound)
        if found:
            return value
        if value == float('inf'):
            return None
        bound = value

双向 BFS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from collections import deque

def bidirectional_bfs(graph, start, end):
    """Return the number of edges on a shortest path from ``start`` to
    ``end``, or -1 if no path exists.

    ``graph[node]`` must yield the neighbours of ``node``. Both searches
    follow these edges as given, so for an undirected graph each edge must be
    listed in both directions.

    Each step expands one whole BFS level, always from the smaller frontier.
    When a level touches the other search, the minimum over every meeting in
    that level is returned. Stopping at the first meeting found while
    expanding node by node can overestimate the distance.
    """
    if start == end:
        return 0

    dist_start = {start: 0}
    dist_end = {end: 0}
    frontier_start = deque([start])
    frontier_end = deque([end])

    def expand_level(frontier, own, other):
        # Expand exactly the nodes currently in ``frontier`` (one BFS level).
        # Return the best start-to-end length found through a meeting, or None.
        best = None
        for _ in range(len(frontier)):
            node = frontier.popleft()
            for neighbor in graph[node]:
                if neighbor in other:
                    total = own[node] + 1 + other[neighbor]
                    if best is None or total < best:
                        best = total
                elif neighbor not in own:
                    own[neighbor] = own[node] + 1
                    frontier.append(neighbor)
        return best

    # If one side runs out of nodes, its component is fully explored with
    # no meeting, so no path exists.
    while frontier_start and frontier_end:
        if len(frontier_start) <= len(frontier_end):
            best = expand_level(frontier_start, dist_start, dist_end)
        else:
            best = expand_level(frontier_end, dist_end, dist_start)
        if best is not None:
            return best

    return -1

迭代加深搜索

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def iterative_deepening_dfs(graph, start, goal, max_depth):
    """Find a path from ``start`` to ``goal`` that uses at most ``max_depth`` edges.

    Runs a depth-limited DFS with limits 0, 1, ..., max_depth. Nodes already
    on the current path are not revisited. Returns the first path found as a
    list of nodes (so shallower paths win), or None.
    """
    def limited(node, remaining, on_path):
        if node == goal:
            return [node]
        if remaining == 0:
            return None

        on_path.add(node)
        for nxt in graph[node]:
            if nxt in on_path:
                continue
            tail = limited(nxt, remaining - 1, on_path)
            if tail:
                return [node] + tail
        on_path.remove(node)
        return None

    for limit in range(max_depth + 1):
        path = limited(start, limit, set())
        if path:
            return path

    return None

三十四、Python 高级特性与优化

生成器与迭代器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def fibonacci_generator():
    """Yield the Fibonacci numbers 0, 1, 1, 2, 3, 5, ... forever."""
    current, following = 0, 1
    while True:
        yield current
        current, following = following, current + following

def range_generator(start, end, step=1):
    """Yield start, start + step, ... stopping before ``end``, like ``range``.

    A positive step counts up while the value is below ``end``; a negative
    step counts down while the value is above ``end``. Unlike ``range``, the
    arguments may be floats.

    Raises:
        ValueError: when iteration begins, if ``step`` is 0 (the loop would
            never terminate).
    """
    if step == 0:
        raise ValueError("range_generator() step must not be zero")
    if step > 0:
        while start < end:
            yield start
            start += step
    else:
        while start > end:
            yield start
            start += step

def batch_generator(data, batch_size):
    """Yield consecutive slices of ``data`` with ``batch_size`` items each;
    the last slice may be shorter."""
    for offset in range(0, len(data), batch_size):
        batch = data[offset:offset + batch_size]
        yield batch

class Counter:
    """Iterator over start, start + 1, ..., end - 1.

    The object is its own iterator, so it can be traversed only once.
    """

    def __init__(self, start, end):
        self.current = start
        self.end = end

    def __iter__(self):
        return self

    def __next__(self):
        exhausted = self.current >= self.end
        if exhausted:
            raise StopIteration
        self.current += 1
        # The value handed out is the one from before the increment.
        return self.current - 1

装饰器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import time
from functools import wraps

def timer(func):
    """Decorator that prints how long each call to ``func`` takes.

    Durations are measured with ``time.perf_counter()``, which the ``time``
    docs recommend for measuring short intervals; ``time.time()`` follows the
    wall clock and can jump when the clock is adjusted. The timing is printed
    even if ``func`` raises, and the exception still propagates.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            print(f"{func.__name__} took {elapsed:.4f} seconds")
    return wrapper

def memoize(func):
    """Cache ``func``'s results, keyed by its positional arguments.

    The arguments must be hashable and keyword arguments are not supported.
    The cache grows without bound.
    """
    cache = {}
    missing = object()  # distinguishes "not cached yet" from any real result

    @wraps(func)
    def wrapper(*args):
        result = cache.get(args, missing)
        if result is missing:
            result = func(*args)
            cache[args] = result
        return result
    return wrapper

def retry(max_attempts=3, delay=1):
    """Decorator factory that retries the wrapped function when it raises.

    The function is called up to ``max_attempts`` times, with a pause of
    ``delay`` seconds after each failed attempt except the last. Any
    ``Exception`` triggers a retry. If every attempt fails, the last
    exception propagates with its original traceback.

    Raises:
        ValueError: if ``max_attempts`` < 1. Otherwise the loop would never
            run and the wrapper would return None without calling the
            function.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except Exception:
                    if attempt == max_attempts:
                        raise
                    time.sleep(delay)
        return wrapper
    return decorator

上下文管理器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from contextlib import contextmanager

@contextmanager
def timer_context(name):
    """Context manager that prints how long its ``with`` body took.

    Uses ``time.perf_counter()`` for the interval. The time is printed even if
    the body raises (via ``finally``); the exception still propagates.
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        print(f"{name} took {elapsed:.4f} seconds")

@contextmanager
def file_handler(filename, mode='r'):
    """Open ``filename`` with ``mode`` and yield the file object.

    The file is closed when the ``with`` block exits, whether normally or
    through an exception.
    """
    with open(filename, mode) as handle:
        yield handle

class Timer:
    """Context manager that prints the elapsed time of its ``with`` body.

    ``start`` and ``end`` hold ``time.perf_counter()`` readings; that clock is
    meant for measuring intervals, unlike the wall-clock ``time.time()``.
    Exceptions raised in the body are not suppressed.
    """

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        self.end = time.perf_counter()
        print(f"Elapsed: {self.end - self.start:.4f} seconds")

性能优化技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import sys
from io import StringIO

def fast_io():
    """Fast echo: read a count n, then copy the next n lines, stripped of
    surrounding whitespace, to stdout in a single write.

    Input is read with ``sys.stdin.readline``; output is buffered in a
    StringIO to avoid many small writes.
    """
    read_line = sys.stdin.readline
    buffer = StringIO()

    count = int(read_line())
    for _ in range(count):
        buffer.write(read_line().strip() + '\n')

    sys.stdout.write(buffer.getvalue())

def use_local_variables():
    """Show how binding a frequently called function to a local name helps.

    Sums 0..9999, calling ``int.__add__`` through the local name ``add`` so
    the loop does a fast local lookup each time. Returns 49995000.

    Note: binding ``result.__add__`` instead would capture the method of the
    initial value 0, so every call would compute ``0 + x``.
    """
    data = list(range(10000))
    result = 0
    add = int.__add__  # looked up once, outside the loop
    for x in data:
        result = add(result, x)
    return result

def avoid_global_lookups():
    """Return the square roots of 0..9999.

    ``sqrt`` is imported into the function, so the comprehension finds it
    with a local-name lookup for each element.
    """
    from math import sqrt as root
    values = list(range(10000))
    return [root(v) for v in values]

def use_builtins():
    """Sum 0..9999 with the C-implemented built-in ``sum`` instead of a Python loop."""
    numbers = [*range(10000)]
    return sum(numbers)

三十五、附录

常用数学公式

公式名称公式
等差数列求和S = n(a₁ + aₙ) / 2
等比数列求和S = a₁(qⁿ - 1) / (q - 1)
组合数C(n, k) = n! / (k!(n-k)!)
排列数P(n, k) = n! / (n-k)!
二项式定理(a+b)ⁿ = Σ C(n,k) aᵏ bⁿ⁻ᵏ
斐波那契通项Fₙ = (φⁿ - ψⁿ) / √5

时间复杂度速查

复杂度 \ n: 10 / 100 / 1000 / 10000 / 100000 / 1000000
O(n!)
O(2ⁿ)
O(n²)
O(n√n)
O(n log n)
O(n)
O(log n)

Python 内置函数复杂度

操作复杂度
len(list)O(1)
list.append()O(1)
list.pop()O(1)
list.insert(i, x)O(n)
list.remove(x)O(n)
x in listO(n)
x in setO(1)
x in dictO(1)
dict[key]O(1)
sorted(list)O(n log n)

常见错误代码

错误类型示例解决方案
IndexErrorarr[10] 当 len(arr) < 11检查索引范围
KeyErrord['key'] 当 key 不存在使用 d.get('key', default)
ValueErrorint('abc')使用 try-except 或 isdigit()
TypeError'1' + 1类型转换 str(1) 或 int('1')
ZeroDivisionError1 / 0检查除数是否为 0
RecursionError递归太深设置 sys.setrecursionlimit()

ASCII 码表

字符ASCII字符ASCII字符ASCII
'0'48'A'65'a'97
'1'49'B'66'b'98
'9'57'Z'90'z'122

常用常量

1
2
3
4
5
6
7
8
9
10
import math

# Frequently used constants.
PI = math.pi
E = math.e
INF = float('inf')
NEG_INF = -INF

MOD = 1_000_000_007       # common prime modulus, 10**9 + 7
MAX_INT = (1 << 31) - 1   # largest signed 32-bit integer
MIN_INT = -(1 << 31)      # smallest signed 32-bit integer

三十六、分块与莫队算法

分块思想

分块是一种通过将数据分成若干块来优化查询的数据结构思想。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class BlockArray:
    """Square-root decomposition over a list, supporting point updates and
    inclusive range-sum / range-max queries.

    The list is split into blocks of ``int(sqrt(n)) + 1`` elements. Each block
    stores its sum and maximum, so a query scans at most two partial blocks
    element by element and uses the stored values for the whole blocks in
    between.
    """

    def __init__(self, arr):
        self.arr = arr[:]
        self.n = len(arr)
        self.block_size = int(self.n ** 0.5) + 1
        self.block_count = (self.n + self.block_size - 1) // self.block_size
        self.block_max = [float('-inf')] * self.block_count
        self.block_sum = [0] * self.block_count

        for index, value in enumerate(arr):
            block = index // self.block_size
            if value > self.block_max[block]:
                self.block_max[block] = value
            self.block_sum[block] += value

    @staticmethod
    def _fold_max(values):
        # Running maximum seeded with -inf; keeps the earlier value on ties.
        best = float('-inf')
        for v in values:
            if v > best:
                best = v
        return best

    def _pieces(self, l, r, per_block):
        # Yield, in order, the values covering positions l..r: single elements
        # from partial blocks and ``per_block`` aggregates for whole blocks.
        size = self.block_size
        first, last = l // size, r // size
        if first == last:
            for i in range(l, r + 1):
                yield self.arr[i]
            return
        for i in range(l, (first + 1) * size):
            yield self.arr[i]
        for b in range(first + 1, last):
            yield per_block[b]
        for i in range(last * size, r + 1):
            yield self.arr[i]

    def update(self, idx, val):
        block = idx // self.block_size
        delta = val - self.arr[idx]
        self.arr[idx] = val
        self.block_sum[block] += delta

        lo = block * self.block_size
        hi = min(lo + self.block_size, self.n)
        self.block_max[block] = self._fold_max(self.arr[i] for i in range(lo, hi))

    def range_sum(self, l, r):
        total = 0
        for v in self._pieces(l, r, self.block_sum):
            total += v
        return total

    def range_max(self, l, r):
        return self._fold_max(self._pieces(l, r, self.block_max))

普通莫队算法

莫队算法用于离线处理区间查询问题,通过巧妙地排序查询顺序来优化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def mo_algorithm(arr, queries):
    """Answer offline range-sum queries with Mo's algorithm.

    ``queries`` is a list of inclusive (l, r) pairs. The result holds
    sum(arr[l..r]) for each query, in the original order. Queries are visited
    sorted by the block of l, with r ascending in even blocks and descending
    in odd blocks, which reduces pointer movement.
    """
    block = int(len(arr) ** 0.5) + 1

    def order_key(item):
        left, right, _ = item
        bucket = left // block
        return (bucket, right if bucket % 2 == 0 else -right)

    pending = sorted(((l, r, i) for i, (l, r) in enumerate(queries)), key=order_key)

    answers = [0] * len(queries)
    lo, hi, window = 0, -1, 0
    for l, r, i in pending:
        while lo > l:
            lo -= 1
            window += arr[lo]
        while hi < r:
            hi += 1
            window += arr[hi]
        while lo < l:
            window -= arr[lo]
            lo += 1
        while hi > r:
            window -= arr[hi]
            hi -= 1
        answers[i] = window

    return answers

带修莫队

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def mo_with_updates(arr, queries, updates):
    """Answer range-sum queries interleaved with point assignments, using
    Mo's algorithm with a time dimension.

    Args:
        arr: initial values; not modified.
        queries: processed in order. ``('Q', l, r)`` asks for
            sum(arr[l..r]) (inclusive) at that moment; any other entry
            ``(tag, pos, value)`` sets ``arr[pos] = value`` for later queries.
        updates: unused — updates are read from ``queries``. Kept for
            backward compatibility.

    Returns:
        The answers to the ``'Q'`` entries, in the order they appear.
    """
    n = len(arr)
    block_size = int(n ** (2 / 3)) + 1

    current_arr = arr[:]
    query_list = []   # (l, r, number of updates before the query, answer slot)
    update_list = []  # [pos, value]; value is swapped with the array on apply

    for q in queries:
        if q[0] == 'Q':
            query_list.append((q[1], q[2], len(update_list), len(query_list)))
        else:
            update_list.append([q[1], q[2]])

    query_list.sort(key=lambda x: (
        x[0] // block_size,
        x[1] // block_size,
        x[2]
    ))

    current_l, current_r, current_update = 0, -1, 0
    current_answer = 0

    def toggle_update(t):
        # Swap update t's stored value with the value in the array. Calling
        # this again for the same t restores the previous value, so one
        # function both applies and rolls back an update.
        nonlocal current_answer
        pos, value = update_list[t]
        old = current_arr[pos]
        if current_l <= pos <= current_r:
            current_answer += value - old
        current_arr[pos] = value
        update_list[t][1] = old

    results = [0] * len(query_list)

    for l, r, t, query_idx in query_list:
        while current_update < t:
            toggle_update(current_update)
            current_update += 1
        while current_update > t:
            current_update -= 1
            toggle_update(current_update)

        while current_l > l:
            current_l -= 1
            current_answer += current_arr[current_l]
        while current_r < r:
            current_r += 1
            current_answer += current_arr[current_r]
        while current_l < l:
            current_answer -= current_arr[current_l]
            current_l += 1
        while current_r > r:
            current_answer -= current_arr[current_r]
            current_r -= 1

        results[query_idx] = current_answer

    return results

树上莫队

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def tree_mo(adj, values, queries, root=0):
    """Count distinct values on tree paths with Mo's algorithm over an Euler tour.

    Args:
        adj: adjacency lists of an undirected tree on nodes 0..n-1 (n >= 1).
        values: values[v] is the hashable value stored at node v.
        queries: list of (u, v) node pairs.
        root: node at which the Euler tour starts.

    Returns:
        For each query, the number of distinct values on the path from u to v,
        both endpoints included.

    The tour records every node twice, on entry and on exit. If u is an
    ancestor of v, the path is the set of nodes that appear exactly once in
    euler[first[u]..first[v]]. Otherwise it is the nodes appearing once in
    euler[last[u]..first[v]], plus lca(u, v).
    """
    n = len(adj)

    euler = []
    first = [-1] * n
    last = [-1] * n
    up0 = list(range(n))  # direct parent of each node; the root maps to itself
    stack = [(root, -1, 0)]

    while stack:
        node, parent, state = stack.pop()
        if state == 0:
            first[node] = len(euler)
            euler.append(node)
            if parent != -1:
                up0[node] = parent
            stack.append((node, parent, 1))
            for child in reversed(adj[node]):
                if child != parent:
                    stack.append((child, node, 0))
        else:
            last[node] = len(euler)
            euler.append(node)

    # Binary lifting: up[k][v] is the 2**k-th ancestor of v, clamped at the root.
    log = max(1, n.bit_length())
    up = [up0]
    for _ in range(1, log):
        prev = up[-1]
        up.append([prev[prev[v]] for v in range(n)])

    def is_ancestor(a, b):
        # a is b or an ancestor of b iff b's tour interval nests inside a's.
        return first[a] <= first[b] and last[b] <= last[a]

    def lca(u, v):
        if is_ancestor(u, v):
            return u
        if is_ancestor(v, u):
            return v
        # Climb u to the highest ancestor that is still not above v.
        for k in range(log - 1, -1, -1):
            if not is_ancestor(up[k][u], v):
                u = up[k][u]
        return up[0][u]

    indexed_queries = []
    for i, (u, v) in enumerate(queries):
        if first[u] > first[v]:
            u, v = v, u
        w = lca(u, v)
        if w == u:
            indexed_queries.append((first[u], first[v], -1, i))
        else:
            indexed_queries.append((last[u], first[v], w, i))

    block_size = int(len(euler) ** 0.5) + 1
    indexed_queries.sort(key=lambda x: (x[0] // block_size, x[1]))

    count = {}  # value -> number of in-range nodes holding it
    in_range = [False] * n
    current_answer = 0

    def toggle(node):
        # Flip node's membership; a node seen twice in the window drops out.
        nonlocal current_answer
        val = values[node]
        if in_range[node]:
            count[val] -= 1
            if count[val] == 0:
                current_answer -= 1
        else:
            if count.get(val, 0) == 0:
                current_answer += 1
            count[val] = count.get(val, 0) + 1
        in_range[node] = not in_range[node]

    results = [0] * len(queries)
    current_l, current_r = 0, -1

    for l, r, lca_node, idx in indexed_queries:
        while current_l > l:
            current_l -= 1
            toggle(euler[current_l])
        while current_r < r:
            current_r += 1
            toggle(euler[current_r])
        while current_l < l:
            toggle(euler[current_l])
            current_l += 1
        while current_r > r:
            toggle(euler[current_r])
            current_r -= 1

        # The LCA is not inside the window in the non-ancestor case; add it
        # only for this answer.
        if lca_node != -1:
            toggle(lca_node)

        results[idx] = current_answer

        if lca_node != -1:
            toggle(lca_node)

    return results

三十七、字符串哈希与后缀自动机

字符串哈希

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class StringHash:
    """Polynomial rolling hash of a string, giving O(1) substring hashes.

    ``hash[i]`` is the hash of ``s[:i]`` and ``power[i]`` is ``base**i``, both
    taken mod ``mod``. Substring bounds are inclusive: ``get_hash(l, r)``
    hashes ``s[l..r]``.
    """

    def __init__(self, s, base=131, mod=10**9 + 7):
        self.s = s
        self.n = len(s)
        self.base = base
        self.mod = mod
        self.power = [1]
        self.hash = [0]
        for ch in s:
            self.power.append(self.power[-1] * base % mod)
            self.hash.append((self.hash[-1] * base + ord(ch)) % mod)

    def get_hash(self, l, r):
        width = r - l + 1
        # Python's % is non-negative, so no "+ mod" correction is needed.
        return (self.hash[r + 1] - self.hash[l] * self.power[width]) % self.mod

    def is_equal(self, l1, r1, l2, r2):
        return self.get_hash(l1, r1) == self.get_hash(l2, r2)

    def find_pattern(self, pattern):
        """Return the start indices where ``pattern``'s hash matches
        (hash equality only; a collision could give a false match)."""
        target = 0
        for ch in pattern:
            target = (target * self.base + ord(ch)) % self.mod
        width = len(pattern)
        return [i for i in range(self.n - width + 1)
                if self.get_hash(i, i + width - 1) == target]

class DoubleHash:
    """Two independent StringHash objects; requiring both hashes to match
    makes an accidental collision far less likely."""

    def __init__(self, s):
        self.h1 = StringHash(s, 131, 10**9 + 7)
        self.h2 = StringHash(s, 137, 10**9 + 9)

    def get_hash(self, l, r):
        return self.h1.get_hash(l, r), self.h2.get_hash(l, r)

    def is_equal(self, l1, r1, l2, r2):
        left = self.get_hash(l1, r1)
        right = self.get_hash(l2, r2)
        return left == right

后缀自动机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class SuffixAutomaton:
    """Suffix automaton of a string, built online one character at a time.

    Each state is a dict with:
      'link' — index of the suffix-link target state (-1 for state 0),
      'len'  — length of the longest string in the state's class,
      'next' — dict mapping a character to the target state index.
    State 0 is the initial (empty-string) state.
    """
    def __init__(self, s):
        self.states = [{'link': -1, 'len': 0, 'next': {}}]
        self.last = 0  # state reached after reading the whole string so far

        for c in s:
            self.extend(c)

    def extend(self, c):
        """Append character ``c`` to the automaton (standard online construction)."""
        p = self.last
        curr = len(self.states)
        self.states.append({'link': 0, 'len': self.states[p]['len'] + 1, 'next': {}})

        # Give every suffix state that lacks a c-transition one to the new state.
        while p != -1 and c not in self.states[p]['next']:
            self.states[p]['next'][c] = curr
            p = self.states[p]['link']

        # If p reached -1, the new state's link stays 0 (the initial state).
        if p != -1:
            q = self.states[p]['next'][c]
            if self.states[p]['len'] + 1 == self.states[q]['len']:
                self.states[curr]['link'] = q
            else:
                # q's class is too long: split it by cloning q with the shorter length.
                clone = len(self.states)
                self.states.append({
                    'link': self.states[q]['link'],
                    'len': self.states[p]['len'] + 1,
                    'next': self.states[q]['next'].copy()
                })

                # Redirect c-transitions that pointed at q along the suffix-link chain.
                while p != -1 and self.states[p]['next'].get(c) == q:
                    self.states[p]['next'][c] = clone
                    p = self.states[p]['link']

                self.states[q]['link'] = clone
                self.states[curr]['link'] = clone

        self.last = curr

    def count_substrings(self):
        """Return the number of distinct non-empty substrings of the string."""
        result = 0
        # Each non-initial state contributes len(state) - len(link(state)) substrings.
        for i in range(1, len(self.states)):
            result += self.states[i]['len'] - self.states[self.states[i]['link']]['len']
        return result

    def longest_common_substring(self, t):
        """Return the length of the longest common substring of the
        automaton's string and ``t``."""
        v = 0     # current state
        l = 0     # length of the match ending at the current position of t
        best = 0

        for c in t:
            # Shorten the match via suffix links until c can extend it.
            while v != 0 and c not in self.states[v]['next']:
                v = self.states[v]['link']
                l = self.states[v]['len']

            if c in self.states[v]['next']:
                v = self.states[v]['next'][c]
                l += 1

            best = max(best, l)

        return best

    def is_substring(self, t):
        """Return True if ``t`` (including the empty string) is a substring."""
        v = 0
        for c in t:
            if c not in self.states[v]['next']:
                return False
            v = self.states[v]['next'][c]
        return True

回文自动机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class PalindromicTree:
    """Palindromic tree (eertree) of ``s``.

    Node 0 is the root for even palindromes (len 0) and node 1 the imaginary
    root for odd palindromes (len -1). Every other node is a distinct
    palindromic substring. Each node dict has 'next' (char -> child),
    'link' (node of its longest proper palindromic suffix), 'len', and
    'count' (number of positions where it was the longest palindromic suffix).
    """

    def __init__(self, s):
        self.s = s
        self.n = len(s)
        # The even root links to the imaginary root, and the imaginary root to
        # itself. Its length -1 always matches in get_link, which ends the walk.
        self.tree = [{'next': {}, 'link': 1, 'len': 0, 'count': 0}]
        self.tree.append({'next': {}, 'link': 1, 'len': -1, 'count': 0})

        self.last = 1
        self.suffix_link = [0] * (self.n + 1)
        self.num = 2

        for i in range(self.n):
            self.add(i)

    def get_link(self, node, pos):
        """Follow suffix links from ``node`` until s[pos] can extend it on both sides."""
        while True:
            cur_len = self.tree[node]['len']
            if pos - 1 - cur_len >= 0 and self.s[pos] == self.s[pos - 1 - cur_len]:
                return node
            node = self.tree[node]['link']

    def add(self, pos):
        """Extend the tree with the character s[pos]."""
        cur = self.get_link(self.last, pos)
        c = self.s[pos]

        if c not in self.tree[cur]['next']:
            new_node = self.num
            self.num += 1
            self.tree.append({'next': {}, 'link': 0, 'len': self.tree[cur]['len'] + 2, 'count': 0})
            self.tree[cur]['next'][c] = new_node

            if self.tree[new_node]['len'] == 1:
                # A single character's longest proper palindromic suffix is
                # empty, so it links to the even root (node 0). This lets
                # even palindromes such as "aa" be built from it.
                self.tree[new_node]['link'] = 0
            else:
                link_node = self.get_link(self.tree[cur]['link'], pos)
                self.tree[new_node]['link'] = self.tree[link_node]['next'][c]

        self.last = self.tree[cur]['next'][c]
        self.tree[self.last]['count'] += 1

    def count_distinct_palindromes(self):
        """Return the number of distinct non-empty palindromic substrings."""
        return self.num - 2

    def count_all_palindromes(self):
        """Return the number of palindromic substring occurrences, with multiplicity.

        Works on a copy of the counts, so repeated calls return the same value.
        """
        counts = [node['count'] for node in self.tree]
        total = 0
        # A node's suffix link always has a smaller index, so a reverse sweep
        # passes every node's occurrences down to all of its suffixes.
        for i in range(self.num - 1, 1, -1):
            counts[self.tree[i]['link']] += counts[i]
            total += counts[i]
        return total

三十八、树上问题专题

树链剖分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
class HeavyLightDecomposition:
    """Heavy-light decomposition of a rooted tree plus a max segment tree.

    Supports path-maximum queries and point updates.

    Args:
        adj: adjacency lists of an undirected tree on nodes 0..n-1.
        values: values[v] is the initial value of node v.
        root: the root node.

    Both DFS passes are iterative, so very deep trees (for example a long
    chain) do not hit Python's recursion limit. They visit nodes in the same
    order as the recursive formulation.
    """

    def __init__(self, adj, values, root=0):
        self.n = len(adj)
        self.adj = adj
        self.values = values
        self.root = root

        self.parent = [-1] * self.n
        self.depth = [0] * self.n
        self.size = [0] * self.n
        self.heavy = [-1] * self.n
        self.head = [0] * self.n
        self.pos = [0] * self.n
        self.order = []

        self._dfs_size(root, -1)
        self._dfs_decompose(root, root)

        self.seg_tree = [0] * (4 * self.n)
        for i, node in enumerate(self.order):
            self._update(1, 0, self.n - 1, i, values[node])

    def _dfs_size(self, node, par):
        """Fill parent, depth, subtree size and heavy child for node's subtree."""
        self.parent[node] = par
        visit = [node]
        stack = [node]
        while stack:
            u = stack.pop()
            for child in self.adj[u]:
                if child != self.parent[u]:
                    self.parent[child] = u
                    self.depth[child] = self.depth[u] + 1
                    visit.append(child)
                    stack.append(child)

        # Every child appears after its parent in ``visit``. Processing in
        # reverse therefore finalises subtree sizes before their parent reads them.
        for u in reversed(visit):
            self.size[u] = 1
            max_size = 0
            for child in self.adj[u]:
                if child != self.parent[u]:
                    self.size[u] += self.size[child]
                    if self.size[child] > max_size:
                        max_size = self.size[child]
                        self.heavy[u] = child

    def _dfs_decompose(self, node, head):
        """Assign chain heads and positions in preorder, heavy child first."""
        stack = [(node, head)]
        while stack:
            u, h = stack.pop()
            self.head[u] = h
            self.pos[u] = len(self.order)
            self.order.append(u)

            # Push light children in reverse so they pop in adjacency order,
            # then push the heavy child so it is visited right after u.
            for child in reversed(self.adj[u]):
                if child != self.parent[u] and child != self.heavy[u]:
                    stack.append((child, child))
            if self.heavy[u] != -1:
                stack.append((self.heavy[u], h))

    def _update(self, node, l, r, idx, val):
        if l == r:
            self.seg_tree[node] = val
        else:
            mid = (l + r) // 2
            if idx <= mid:
                self._update(node * 2, l, mid, idx, val)
            else:
                self._update(node * 2 + 1, mid + 1, r, idx, val)
            self.seg_tree[node] = max(self.seg_tree[node * 2], self.seg_tree[node * 2 + 1])

    def _query(self, node, l, r, ql, qr):
        if ql > r or qr < l:
            return float('-inf')
        if ql <= l and r <= qr:
            return self.seg_tree[node]
        mid = (l + r) // 2
        return max(self._query(node * 2, l, mid, ql, qr), self._query(node * 2 + 1, mid + 1, r, ql, qr))

    def path_query(self, u, v):
        """Return the maximum node value on the path from u to v."""
        result = float('-inf')
        while self.head[u] != self.head[v]:
            if self.depth[self.head[u]] > self.depth[self.head[v]]:
                u, v = v, u
            result = max(result, self._query(1, 0, self.n - 1, self.pos[self.head[v]], self.pos[v]))
            v = self.parent[self.head[v]]

        if self.depth[u] > self.depth[v]:
            u, v = v, u
        result = max(result, self._query(1, 0, self.n - 1, self.pos[u], self.pos[v]))

        return result

    def point_update(self, node, val):
        """Set the value of ``node`` to ``val``."""
        self._update(1, 0, self.n - 1, self.pos[node], val)
LCA 最近公共祖先

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class LCA:
    """Lowest common ancestor queries on a rooted tree via binary lifting.

    ``parent[k][v]`` is the 2**k-th ancestor of v, or -1 above the root.
    The initial DFS is iterative, so deep trees do not hit Python's
    recursion limit.
    """

    def __init__(self, adj, root=0):
        self.n = len(adj)
        self.adj = adj
        self.root = root
        self.log = (self.n).bit_length()

        self.depth = [0] * self.n
        self.parent = [[-1] * self.n for _ in range(self.log)]

        self._dfs(root, -1)
        self._build_sparse_table()

    def _dfs(self, node, par):
        """Fill direct parents and depths for node's subtree (iterative DFS)."""
        self.parent[0][node] = par
        stack = [node]
        while stack:
            u = stack.pop()
            for child in self.adj[u]:
                if child != self.parent[0][u]:
                    self.parent[0][child] = u
                    self.depth[child] = self.depth[u] + 1
                    stack.append(child)

    def _build_sparse_table(self):
        for k in range(1, self.log):
            for v in range(self.n):
                if self.parent[k - 1][v] != -1:
                    self.parent[k][v] = self.parent[k - 1][self.parent[k - 1][v]]

    def lca(self, u, v):
        """Return the lowest common ancestor of u and v."""
        if self.depth[u] < self.depth[v]:
            u, v = v, u

        # Lift the deeper node to the same depth.
        diff = self.depth[u] - self.depth[v]
        for k in range(self.log):
            if diff >> k & 1:
                u = self.parent[k][u]

        if u == v:
            return u

        # Lift both as long as their ancestors differ.
        for k in range(self.log - 1, -1, -1):
            if self.parent[k][u] != self.parent[k][v]:
                u = self.parent[k][u]
                v = self.parent[k][v]

        return self.parent[0][u]

    def distance(self, u, v):
        """Return the number of edges between u and v."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]

    def kth_ancestor(self, node, k):
        """Return the k-th ancestor of node, or -1 if k exceeds its depth."""
        if k > self.depth[node]:
            return -1
        for i in range(self.log):
            if k >> i & 1:
                node = self.parent[i][node]
        return node

树的中心与直径

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def tree_center_and_diameter(adj):
    """Return (centers, diameter) of an unweighted tree given as adjacency lists.

    Two BFS passes find a diameter: first the farthest node from 0, then the
    farthest node from that one. A node is a center when it lies on that
    diameter path and its distances to the two ends differ by at most 1. This
    gives one center for an even diameter and two for an odd one.
    """
    n = len(adj)

    def bfs_farthest(source):
        dist = [-1] * n
        dist[source] = 0
        order = [source]
        best = source
        head = 0
        while head < len(order):
            node = order[head]
            head += 1
            for nxt in adj[node]:
                if dist[nxt] != -1:
                    continue
                dist[nxt] = dist[node] + 1
                order.append(nxt)
                if dist[nxt] > dist[best]:
                    best = nxt
        return best, dist

    end1, _ = bfs_farthest(0)
    end2, from_end1 = bfs_farthest(end1)
    _, from_end2 = bfs_farthest(end2)

    diameter = from_end1[end2]
    centers = [
        v for v in range(n)
        if from_end1[v] + from_end2[v] == diameter
        and abs(from_end1[v] - from_end2[v]) <= 1
    ]
    return centers, diameter

def tree_radius(adj):
    """Return the tree's radius (the eccentricity of a center), ceil(diameter / 2)."""
    _, diameter = tree_center_and_diameter(adj)
    return (diameter + 1) // 2

虚树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def virtual_tree(adj, key_nodes):
    """Build the virtual (auxiliary) tree of ``key_nodes`` in a tree rooted at 0.

    The virtual tree keeps the key nodes plus the LCAs needed to preserve
    their ancestry. Each kept node has an edge to its nearest kept descendant
    on every branch.

    Args:
        adj: adjacency lists of an undirected tree on nodes 0..n-1.
        key_nodes: iterable of node ids; duplicates are ignored.

    Returns:
        (virtual_edges, root): virtual_edges[u] lists the virtual-tree
        children of u, and root is the virtual tree's root, or None when
        ``key_nodes`` is empty.
    """
    n = len(adj)

    # Remove duplicates; a repeated node would otherwise be pushed twice and
    # produce a self-loop edge.
    key_nodes = set(key_nodes)
    if not key_nodes:
        return [[] for _ in range(n)], None

    depth = [0] * n
    parent = [-1] * n
    dfn = [0] * n
    order = []

    # Iterative DFS preorder from root 0.
    stack = [(0, -1)]
    while stack:
        node, par = stack.pop()
        parent[node] = par
        dfn[node] = len(order)
        order.append(node)
        for child in adj[node]:
            if child != par:
                depth[child] = depth[node] + 1
                stack.append((child, node))

    log = (n).bit_length()
    up = [[-1] * n for _ in range(log)]
    up[0] = parent[:]
    for k in range(1, log):
        for v in range(n):
            if up[k - 1][v] != -1:
                up[k][v] = up[k - 1][up[k - 1][v]]

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        diff = depth[u] - depth[v]
        for k in range(log):
            if diff >> k & 1:
                u = up[k][u]
        if u == v:
            return u
        for k in range(log - 1, -1, -1):
            if up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return parent[u]

    key_nodes = sorted(key_nodes, key=lambda x: dfn[x])
    stack = [key_nodes[0]]  # nodes on the current root-to-leaf chain
    virtual_edges = [[] for _ in range(n)]

    for node in key_nodes[1:]:
        l = lca(node, stack[-1])

        # Close chain segments that lie below the new LCA.
        while len(stack) >= 2 and depth[stack[-2]] >= depth[l]:
            virtual_edges[stack[-2]].append(stack[-1])
            stack.pop()

        if stack[-1] != l:
            virtual_edges[l].append(stack[-1])
            stack.pop()
            if not stack or stack[-1] != l:
                stack.append(l)

        stack.append(node)

    while len(stack) >= 2:
        virtual_edges[stack[-2]].append(stack[-1])
        stack.pop()

    return virtual_edges, stack[0]

三十九、CDQ 分治与整体二分

CDQ 分治

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def cdq_divide_conquer(points):
    """Count 3-D dominance with CDQ divide and conquer (三维偏序).

    For every point (x, y, z), count the other points (x', y', z') with
    x' <= x, y' <= y and z' <= z. Identical points count each other.

    Args:
        points: list of (x, y, z) tuples of mutually comparable numbers.

    Returns:
        A list of (x, y, z, i, count) tuples sorted by (x, y, z, i), where i
        is the point's index in ``points``.
    """
    tagged = sorted((x, y, z, i) for i, (x, y, z) in enumerate(points))

    # Merge identical coordinates so that equal points are counted exactly.
    # Each group is [x, y, z, multiplicity, dominated_by_other_groups].
    groups_by_coord = {}
    for x, y, z, _ in tagged:
        group = groups_by_coord.get((x, y, z))
        if group is None:
            groups_by_coord[(x, y, z)] = [x, y, z, 1, 0]
        else:
            group[3] += 1
    groups = list(groups_by_coord.values())  # insertion order is (x, y, z) order

    # Fenwick tree over compressed z ranks 1..size. It is allocated once and
    # cleared after every merge step.
    z_rank = {z: r for r, z in enumerate(sorted({g[2] for g in groups}), start=1)}
    size = len(z_rank)
    fenwick = [0] * (size + 1)

    def update(i, val):
        while i <= size:
            fenwick[i] += val
            i += i & -i

    def query(i):
        result = 0
        while i > 0:
            result += fenwick[i]
            i -= i & -i
        return result

    def solve(lo, hi):
        # Handle groups[lo:hi] (half-open). In a cross pair the left group
        # comes first in (x, y, z) order, so x already fits and only y and z
        # need checking.
        if hi - lo <= 1:
            return
        mid = (lo + hi) // 2
        solve(lo, mid)
        solve(mid, hi)

        left = sorted(groups[lo:mid], key=lambda g: g[1])
        right = sorted(groups[mid:hi], key=lambda g: g[1])
        j = 0
        for g in right:
            while j < len(left) and left[j][1] <= g[1]:
                update(z_rank[left[j][2]], left[j][3])
                j += 1
            g[4] += query(z_rank[g[2]])
        for k in range(j):
            update(z_rank[left[k][2]], -left[k][3])

    solve(0, len(groups))

    result = []
    for x, y, z, i in tagged:
        mult, dominated = groups_by_coord[(x, y, z)][3:5]
        result.append((x, y, z, i, dominated + mult - 1))
    return result

整体二分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def parallel_binary_search(n, queries, check_func):
    """Binary-search an answer in [0, n] for every query, running all searches in lockstep rounds.

    ``check_func(i, mid)`` must be monotone in ``mid`` for query index i
    (False ... False True ... True).

    Returns:
        For each query, the smallest mid in [0, n] for which ``check_func``
        is True, or -1 if there is none.

    Each round groups the pending queries by their current midpoint and
    evaluates the groups in increasing midpoint order.
    """
    low = [0] * len(queries)
    high = [n] * len(queries)
    answers = [-1] * len(queries)

    while True:
        mid_queries = {}
        all_done = True

        for i, (l, h) in enumerate(zip(low, high)):
            if l <= h:
                all_done = False
                mid_queries.setdefault((l + h) // 2, []).append(i)

        if all_done:
            break

        for mid, query_indices in sorted(mid_queries.items()):
            for idx in query_indices:
                if check_func(idx, mid):
                    answers[idx] = mid
                    high[idx] = mid - 1
                else:
                    low[idx] = mid + 1

    return answers

def kth_smallest_queries(arr, queries):
    """Answer (l, r, k) queries: the k-th smallest value (k counted from 1) of arr[l..r], inclusive.

    Returns a list with one answer per query; an answer is None when the
    range holds fewer than k elements.

    For each query, this finds the shortest prefix of arr's elements in
    sorted order that contains at least k elements from positions l..r. The
    last element of that prefix is the answer.
    """
    n = len(arr)
    indexed_arr = sorted((val, i) for i, val in enumerate(arr))

    def check(query_idx, mid):
        l, r, k = queries[query_idx]
        count = 0
        for i in range(mid + 1):
            if l <= indexed_arr[i][1] <= r:
                count += 1
        return count >= k

    ranks = parallel_binary_search(n - 1, queries, check)
    # The search returns positions in sorted order; map them back to values.
    return [indexed_arr[rank][0] if rank != -1 else None for rank in ranks]

四十、随机化算法

模拟退火

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import random
import math

def simulated_annealing(initial_state, energy_func, neighbor_func, iterations=100000, initial_temp=1.0):
    """Minimise ``energy_func`` by simulated annealing.

    The temperature falls linearly from ``initial_temp`` towards 0 over
    ``iterations`` steps. A worse neighbour is accepted with probability
    exp(-delta / T), so ``initial_temp`` should be on the scale of typical
    energy differences. The default 1.0 keeps the original schedule.

    Returns:
        (best_state, best_energy): the lowest-energy state seen.

    Raises:
        ValueError: if ``initial_temp`` is not positive.
    """
    if initial_temp <= 0:
        raise ValueError("initial_temp must be positive")

    state = initial_state
    energy = energy_func(state)
    best_state = state
    best_energy = energy

    for i in range(iterations):
        temperature = initial_temp * (1.0 - i / iterations)  # never 0 for i < iterations

        new_state = neighbor_func(state)
        new_energy = energy_func(new_state)

        delta = new_energy - energy

        if delta < 0 or random.random() < math.exp(-delta / temperature):
            state = new_state
            energy = new_energy

            if energy < best_energy:
                best_state = state
                best_energy = energy

    return best_state, best_energy

def tsp_simulated_annealing(cities, iterations=100000, initial_temp=1.0):
    """Approximate a shortest closed tour through ``cities`` (a list of (x, y) pairs).

    Uses segment-reversal neighbours, starting from a random permutation.

    Returns:
        (tour, length) where tour is a list of city indices.
    """
    n = len(cities)

    def distance(a, b):
        return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5

    def total_distance(path):
        return sum(distance(cities[path[i]], cities[path[(i + 1) % n]]) for i in range(n))

    if n < 3:
        # Every tour has the same length. Also, random.sample(range(n), 2)
        # below would raise ValueError for n < 2.
        path = list(range(n))
        return path, total_distance(path)

    def neighbor(path):
        # Reverse a random segment (a 2-opt move).
        new_path = path[:]
        i, j = sorted(random.sample(range(n), 2))
        new_path[i:j + 1] = reversed(new_path[i:j + 1])
        return new_path

    initial_path = list(range(n))
    random.shuffle(initial_path)

    return simulated_annealing(initial_path, total_distance, neighbor, iterations, initial_temp)

随机哈希

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import random

class RandomHash:
    """Gives each distinct value a random label in 1..10**18, for hashing
    unordered collections.

    ``n`` is stored but not used.
    """

    def __init__(self, n):
        self.n = n
        self.hash_values = {}  # value -> its random label

    def get_hash(self, x):
        """Return the label of ``x``; it is drawn on first use and fixed afterwards."""
        if x not in self.hash_values:
            self.hash_values[x] = random.randint(1, 10**18)
        return self.hash_values[x]

    def hash_array(self, arr):
        """XOR of the labels of ``arr``'s elements.

        The result does not depend on order, but a value occurring an even
        number of times cancels out. It therefore identifies the set of values
        with odd multiplicity, not the multiset.
        """
        result = 0
        for x in arr:
            result ^= self.get_hash(x)
        return result

    def hash_multiset(self, arr):
        """Sum of labels mod 2**64.

        This ignores order but is sensitive to multiplicity: equal multisets
        hash equally, and different ones collide only with negligible
        probability.
        """
        return sum(self.get_hash(x) for x in arr) % (1 << 64)

def subset_hash_check(arr1, arr2):
    """Return True if arr1 and arr2 (almost certainly) hold the same elements
    with the same multiplicities, ignoring order.

    Uses the sum-based multiset hash; an XOR hash would wrongly treat
    [1, 1] and [2, 2] as equal because duplicates cancel.
    """
    rh = RandomHash(100)
    return rh.hash_multiset(arr1) == rh.hash_multiset(arr2)

Miller-Rabin 素性测试

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import random

def miller_rabin(n, k=10):
    """Probabilistic primality test using ``k`` random Miller–Rabin rounds.

    Returns False for every composite that any round detects. A composite
    survives one round with probability at most 1/4.
    """
    if n < 2:
        return False
    if n == 2 or n == 3:
        return True
    if n % 2 == 0:
        return False

    # Write n - 1 = d * 2**r with d odd.
    r, d = 0, n - 1
    while d % 2 == 0:
        r += 1
        d //= 2

    for _ in range(k):
        a = random.randrange(2, n - 1)
        x = pow(a, d, n)

        if x == 1 or x == n - 1:
            continue

        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                break
        else:
            return False

    return True

def pollard_rho(n):
    """Try to find a non-trivial factor of composite ``n`` with Pollard's rho
    (Floyd cycle detection). Returns None when this attempt fails."""
    if n % 2 == 0:
        return 2

    x = random.randint(2, n - 1)
    y = x
    c = random.randint(1, n - 1)
    d = 1

    while d == 1:
        x = (x * x + c) % n
        y = (y * y + c) % n
        y = (y * y + c) % n
        d = math.gcd(abs(x - y), n)

    return d if d != n else None

def factorize(n):
    """Return the prime factors of ``n``, with multiplicity and in no particular order.

    Raises:
        ValueError: if n < 1. Such numbers have no prime factorisation, and
            n == 0 would otherwise recurse forever (0 is even, so it keeps
            splitting into 2 and 0).
    """
    if n < 1:
        raise ValueError("factorize() requires n >= 1")
    if n == 1:
        return []
    if miller_rabin(n):
        return [n]

    # Pollard's rho is randomised; retry until it yields a proper factor.
    factor = pollard_rho(n)
    while factor is None:
        factor = pollard_rho(n)

    return factorize(factor) + factorize(n // factor)

四十一、更多真题与实战练习

2014 年真题:李白打酒

题目:话说大诗人李白,一生好饮。幸好他从不开车。一天,他提着酒壶,从家里出来,酒壶中有酒 2 斗。他边走边唱:无事街上走,提壶去打酒。逢店加一倍,遇花喝一斗。这一路上,他一共遇到店 5 次,遇到花 10 次,已知最后一次遇到的是花,他正好把酒喝光了。请你计算李白遇到店和花的次序,有多少种可能的方案?

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def count_ways():
    """Count the possible orders of 5 shops and 10 flowers in 李白打酒.

    The pot starts with 2 units of wine. A shop doubles it and a flower
    drinks one unit. The wine must never go negative, the last encounter
    must be a flower, and the pot must be empty at the end.
    """
    def explore(wine, shops_left, flowers_left, ended_on_flower):
        if wine < 0:
            return 0
        if not shops_left and not flowers_left:
            return int(wine == 0 and ended_on_flower)

        ways = 0
        if shops_left:
            ways += explore(2 * wine, shops_left - 1, flowers_left, False)
        if flowers_left:
            ways += explore(wine - 1, shops_left, flowers_left - 1, True)
        return ways

    return explore(2, 5, 10, False)

print(count_ways())

2013 年真题:振兴中华

题目:小明参加了学校的趣味运动会,其中的一个项目是:跳格子。地上画着一些格子,每个格子里写一个字,如下所示:

1
2
3
4
从我做起振
我做起振兴
做起振兴中
起振兴中华

比赛时,先站在左上角的“从”字上,可以横向或纵向跳到相邻的格子里,但不能跳到对角的格子或其他位置。一直要跳到“华”字结束。要求跳过的路线刚好构成“从我做起振兴中华”这句话。请你帮助小明算一算他一共有多少种可能的跳跃路线?

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def count_paths():
    """Count the jump routes spelling "从我做起振兴中华" on the 2013 "振兴中华" board.

    The route starts on the top-left "从" and moves one cell horizontally or
    vertically per step; each visited cell must carry the next character.
    """
    from functools import lru_cache

    board = [
        "从我做起振",
        "我做起振兴",
        "做起振兴中",
        "起振兴中华",
    ]
    phrase = "从我做起振兴中华"
    rows, cols = len(board), len(board[0])
    steps = ((0, 1), (1, 0), (0, -1), (-1, 0))

    @lru_cache(None)
    def routes(r, c, k):
        # k is the index of the next phrase character still to be matched
        if k == len(phrase):
            return 1
        return sum(
            routes(nr, nc, k + 1)
            for nr, nc in ((r + dr, c + dc) for dr, dc in steps)
            if 0 <= nr < rows and 0 <= nc < cols and board[nr][nc] == phrase[k]
        )

    return routes(0, 0, 1)

print(count_paths())

2012 年真题:密码发生器

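题目:在对银行账户等重要权限设置密码的时候,我们常常遇到这样的烦恼:如果为了好记用生日吧,容易被破解,不安全;如果设置不好记的密码,又担心自己也会忘记;如果写在纸上,担心纸张被别人发现或弄丢了……这个程序的任务就是把一串拼音字母转换为 6 位数字(密码)。我们可以使用任何好记的拼音串(比如名字,王喜明,就用 wangximing)作为输入,程序输出 6 位数字。变换规则如下:第一步,把字符串 6 个一组折叠起来,例如 wangximing 变为 wangxi 和 ming 两行;第二步,把垂直方向同一列上字符的 ASCII 码值相加,得到 6 个数字,上例得到 228 202 220 206 120 105;第三步,对每个数字反复“缩位”(把各位数字相加),直到变成一位数字为止,例如 228 → 2+2+8=12 → 1+2=3。上例最终输出 344836。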

解题思路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def generate_password(name):
    """Turn a pinyin string into a 6-digit password (2012 "密码发生器").

    The transformation from the problem statement:
    1. fold the string into rows of 6 characters (the last row may be short);
    2. add up the ASCII codes of the characters in each of the 6 columns
       (an empty column sums to 0);
    3. "shrink" each column sum by repeatedly adding its decimal digits
       until a single digit remains.

    >>> generate_password("wangximing")
    '344836'
    """
    column_sums = [0] * 6
    for idx, ch in enumerate(name):
        column_sums[idx % 6] += ord(ch)

    def shrink(value):
        # repeated digit sum until one digit is left, e.g. 228 -> 12 -> 3
        while value >= 10:
            value = sum(int(d) for d in str(value))
        return value

    return ''.join(str(shrink(v)) for v in column_sums)

print(generate_password("wangximing"))

综合练习:表达式求值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def evaluate_expression(s):
    """Evaluate an infix expression of non-negative integers with + - * / and parentheses.

    Two-stack (operand/operator) shunting evaluation; ``/`` is floor
    division. Spaces are ignored and operators are left-associative.
    """
    rank = {'+': 1, '-': 1, '*': 2, '/': 2}
    binary = {
        '+': lambda x, y: x + y,
        '-': lambda x, y: x - y,
        '*': lambda x, y: x * y,
        '/': lambda x, y: x // y,
    }

    operands = []
    operators = []

    def collapse():
        # pop one operator and its two operands, push the result
        right = operands.pop()
        left = operands.pop()
        fn = binary.get(operators.pop())
        operands.append(fn(left, right) if fn else None)

    pos, size = 0, len(s)
    while pos < size:
        ch = s[pos]
        if ch == ' ':
            pos += 1
        elif ch == '(':
            operators.append(ch)
            pos += 1
        elif ch.isdigit():
            begin = pos
            while pos < size and s[pos].isdigit():
                pos += 1
            operands.append(int(s[begin:pos]))
        elif ch == ')':
            while operators and operators[-1] != '(':
                collapse()
            operators.pop()  # discard the matching '('
            pos += 1
        else:
            while operators and rank.get(operators[-1], 0) >= rank.get(ch, 0):
                collapse()
            operators.append(ch)
            pos += 1

    while operators:
        collapse()

    return operands[-1]

print(evaluate_expression("3 + 4 * 2 - 1"))

综合练习:大整数运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class BigInteger:
    """Non-negative arbitrary-precision integer with schoolbook ``+`` and ``*``.

    ``self.digits`` holds base-10 digits least-significant first. Every entry
    is in 0..9 and there are no leading zeros (zero itself is ``[0]``).
    """

    def __init__(self, num=0):
        """Build from a decimal string (e.g. ``"00123"``) or a non-negative int.

        Raises:
            ValueError: for a negative int, an empty string, or a string
                with non-digit characters.
            TypeError: for any other argument type.
        """
        if isinstance(num, str):
            text = num
        elif isinstance(num, int):
            if num < 0:
                raise ValueError("BigInteger supports only non-negative values")
            # Go through the decimal text so every entry is a single digit;
            # storing the int itself (e.g. [123]) breaks the digit invariant.
            text = str(int(num))
        else:
            raise TypeError(f"unsupported type for BigInteger: {type(num).__name__}")
        if not text:
            raise ValueError("empty string is not a number")
        self.digits = [int(c) for c in reversed(text)]
        self._normalize()

    def _normalize(self):
        """Strip leading (most-significant) zeros, keeping at least one digit."""
        if not self.digits:
            self.digits = [0]
        while len(self.digits) > 1 and self.digits[-1] == 0:
            self.digits.pop()

    def __add__(self, other):
        """Return ``self + other`` by digit-wise addition with carry."""
        result = BigInteger()
        result.digits = []
        carry = 0
        max_len = max(len(self.digits), len(other.digits))

        for i in range(max_len):
            a = self.digits[i] if i < len(self.digits) else 0
            b = other.digits[i] if i < len(other.digits) else 0
            total = a + b + carry
            result.digits.append(total % 10)
            carry = total // 10

        if carry:
            result.digits.append(carry)

        result._normalize()
        return result

    def __mul__(self, other):
        """Return ``self * other`` by O(len(a) * len(b)) long multiplication."""
        result = BigInteger()
        result.digits = [0] * (len(self.digits) + len(other.digits))

        for i in range(len(self.digits)):
            carry = 0
            for j in range(len(other.digits)):
                total = result.digits[i + j] + self.digits[i] * other.digits[j] + carry
                result.digits[i + j] = total % 10
                carry = total // 10
            if carry:
                # this slot has not been written by earlier rows yet
                result.digits[i + len(other.digits)] += carry

        result._normalize()
        return result

    def __str__(self):
        return ''.join(map(str, reversed(self.digits)))

    def __repr__(self):
        return self.__str__()

a = BigInteger("12345678901234567890")
b = BigInteger("98765432109876543210")
print(a + b)
print(a * b)

四十二、常用算法思想总结

枚举与暴力

适用场景:数据规模小,没有明显规律

1
2
3
4
5
6
7
8
9
def enumerate_examples():
    """Placeholder for the enumeration / brute-force examples of this section."""


def brute_force_substring(text, pattern):
    """Return the first index at which ``pattern`` occurs in ``text``, or -1.

    Naive O(len(text) * len(pattern)) scan; an empty pattern matches at 0.
    """
    width = len(pattern)
    last_start = len(text) - width
    start = 0
    while start <= last_start:
        if text[start:start + width] == pattern:
            return start
        start += 1
    return -1

贪心算法

适用场景:局部最优能推出全局最优

1
2
3
4
5
6
7
8
9
10
def greedy_examples():
    """Placeholder for the greedy-algorithm examples of this section."""
    pass


def activity_selection(activities):
    """Pick a maximum-size set of non-overlapping activities (earliest-finish-first greedy).

    Args:
        activities: iterable of ``(start, end)`` pairs. An activity may start
            exactly when the previously selected one ends.

    Returns:
        The selected pairs in order of end time; ``[]`` for empty input.
        The caller's sequence is not reordered.
    """
    ordered = sorted(activities, key=lambda x: x[1])  # stable, like list.sort
    selected = []
    for activity in ordered:
        if not selected or activity[0] >= selected[-1][1]:
            selected.append(activity)
    return selected

分治算法

适用场景:问题可分解为独立子问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def divide_conquer_examples():
    """Placeholder for the divide-and-conquer examples of this section."""


def merge_sort(arr):
    """Return a sorted list using top-down merge sort (stable).

    Inputs of length 0 or 1 are returned as-is (the same object).
    """
    if len(arr) < 2:
        return arr
    half = len(arr) // 2
    left_sorted = merge_sort(arr[:half])
    right_sorted = merge_sort(arr[half:])
    return merge(left_sorted, right_sorted)


def merge(left, right):
    """Merge two sorted lists into a new sorted list; ties are taken from ``left`` first."""
    out = []
    li = ri = 0
    nl, nr = len(left), len(right)
    while li < nl and ri < nr:
        if left[li] <= right[ri]:
            out.append(left[li])
            li += 1
        else:
            out.append(right[ri])
            ri += 1
    out += left[li:]
    out += right[ri:]
    return out

动态规划

适用场景:最优子结构 + 重叠子问题

1
2
3
4
5
6
7
8
9
10
def dp_examples():
    """Placeholder for the dynamic-programming examples of this section."""


def knapsack(weights, values, capacity):
    """Solve 0/1 knapsack: the best total value whose total weight is <= ``capacity``.

    One-dimensional DP; ``best[c]`` is the best value within capacity ``c``.
    """
    best = [0] * (capacity + 1)
    for i, wt in enumerate(weights):
        val = values[i]
        # sweep capacities downward so each item is used at most once
        for cap in range(capacity, wt - 1, -1):
            candidate = best[cap - wt] + val
            if candidate > best[cap]:
                best[cap] = candidate
    return best[capacity]

搜索算法

适用场景:状态空间可遍历

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def search_examples():
    """Placeholder for the search-algorithm examples of this section."""


def bfs_shortest_path(graph, start, end):
    """Return a fewest-edges path ``[start, ..., end]`` in an unweighted graph, or None.

    ``graph[node]`` must yield the neighbours of ``node``. Parent pointers are
    recorded when a node is first discovered and walked back at the end.
    """
    from collections import deque

    root = object()  # sentinel parent of ``start``
    parent = {start: root}
    frontier = deque([start])

    while frontier:
        node = frontier.popleft()
        if node == end:
            route = []
            while node is not root:
                route.append(node)
                node = parent[node]
            route.reverse()
            return route
        for nxt in graph[node]:
            if nxt not in parent:
                parent[nxt] = node
                frontier.append(nxt)

    return None

四十三、网络流算法

最大流问题

网络流是图论中的重要算法,用于解决流量分配问题。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from collections import deque

class MaxFlow:
    """Edmonds-Karp maximum flow on a directed graph with nodes ``0..n-1``.

    ``self.capacity[(u, v)]`` is the residual capacity of u -> v, and
    ``self.graph[u]`` lists each node sharing an edge with ``u`` (in either
    direction) exactly once.
    """

    def __init__(self, n):
        self.n = n
        self.graph = [[] for _ in range(n)]
        self.capacity = {}

    def add_edge(self, u, v, cap):
        """Add a directed edge u -> v with capacity ``cap``.

        Parallel edges accumulate their capacities. An edge in the opposite
        direction keeps its own capacity instead of overwriting this one.
        """
        if (u, v) not in self.capacity:
            # First edge between this pair: register adjacency both ways
            # and a zero-capacity residual edge in each direction.
            self.graph[u].append(v)
            self.graph[v].append(u)
            self.capacity[(u, v)] = 0
            self.capacity[(v, u)] = 0
        self.capacity[(u, v)] += cap

    def bfs(self, s, t, parent):
        """Look for a shortest s-t path in the residual graph.

        Fills ``parent[v]`` for each discovered node; returns True if ``t``
        was reached.
        """
        visited = [False] * self.n
        queue = deque([s])
        visited[s] = True

        while queue:
            u = queue.popleft()
            for v in self.graph[u]:
                if not visited[v] and self.capacity[(u, v)] > 0:
                    visited[v] = True
                    parent[v] = u
                    if v == t:
                        return True
                    queue.append(v)
        return False

    def edmonds_karp(self, s, t):
        """Return the maximum flow from ``s`` to ``t``, consuming residual capacities in place."""
        parent = [-1] * self.n
        max_flow = 0

        while self.bfs(s, t, parent):
            # bottleneck capacity along the augmenting path
            path_flow = float('inf')
            v = t
            while v != s:
                u = parent[v]
                path_flow = min(path_flow, self.capacity[(u, v)])
                v = u

            max_flow += path_flow
            v = t
            while v != s:
                u = parent[v]
                self.capacity[(u, v)] -= path_flow
                self.capacity[(v, u)] += path_flow
                v = u

        return max_flow

mf = MaxFlow(6)
mf.add_edge(0, 1, 16)
mf.add_edge(0, 2, 13)
mf.add_edge(1, 2, 10)
mf.add_edge(1, 3, 12)
mf.add_edge(2, 1, 4)
mf.add_edge(2, 4, 14)
mf.add_edge(3, 2, 9)
mf.add_edge(3, 5, 20)
mf.add_edge(4, 3, 7)
mf.add_edge(4, 5, 4)

print(f"最大流: {mf.edmonds_karp(0, 5)}")

Dinic 算法

更高效的最大流算法,适用于大规模网络。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class Dinic:
    """Dinic's maximum-flow algorithm on nodes ``0..n-1``.

    ``self.adj[u]`` holds mutable edge records ``[v, cap, rev]``. ``cap`` is
    the residual capacity of u -> v, and ``rev`` is the index of the paired
    reverse edge inside ``self.adj[v]``.
    """
    def __init__(self, n):
        self.n = n
        self.adj = [[] for _ in range(n)]

    def add_edge(self, u, v, cap):
        """Add edge u -> v with capacity ``cap`` and its zero-capacity reverse edge."""
        # rev of the forward edge = slot the reverse edge is about to take in adj[v]
        self.adj[u].append([v, cap, len(self.adj[v])])
        # rev of the reverse edge = slot of the forward edge just appended to adj[u]
        # NOTE(review): for a self-loop (u == v) the forward edge's rev points at
        # itself; self-loops can never carry flow here, but avoid adding them.
        self.adj[v].append([u, 0, len(self.adj[u]) - 1])

    def bfs_level(self, s, t, level):
        """Fill ``level`` with BFS distances from ``s`` over edges with cap > 0.

        ``level`` must be pre-filled with -1 by the caller. Returns True if
        ``t`` is reachable.
        """
        queue = deque([s])
        level[s] = 0

        while queue:
            u = queue.popleft()
            for v, cap, rev in self.adj[u]:
                if level[v] < 0 and cap > 0:
                    level[v] = level[u] + 1
                    queue.append(v)

        return level[t] >= 0

    def dfs_flow(self, u, t, flow, level, ptr):
        """Push one augmenting path from ``u`` to ``t`` along the level graph.

        Returns the amount pushed (at most ``flow``), or 0 if no path exists.
        ``ptr[u]`` is the first edge of ``u`` still worth trying in this
        phase; edges before it already failed (current-arc optimisation).
        Recursion depth equals the path length.
        """
        if u == t:
            return flow

        for i in range(ptr[u], len(self.adj[u])):
            ptr[u] = i
            v, cap, rev = self.adj[u][i]

            # only follow edges that go exactly one level deeper
            if level[v] == level[u] + 1 and cap > 0:
                pushed = self.dfs_flow(v, t, min(flow, cap), level, ptr)
                if pushed > 0:
                    self.adj[u][i][1] -= pushed
                    self.adj[v][rev][1] += pushed
                    return pushed

        return 0

    def max_flow(self, s, t):
        """Return the maximum s-t flow; residual capacities in ``self.adj`` are updated in place."""
        total = 0

        while True:
            # one phase: rebuild the level graph, then push until it is blocked
            level = [-1] * self.n
            if not self.bfs_level(s, t, level):
                break

            ptr = [0] * self.n
            while True:
                pushed = self.dfs_flow(s, t, float('inf'), level, ptr)
                if pushed == 0:
                    break
                total += pushed

        return total

# Demo: for these edges the max flow is 16 (12 along 0->1->3->5 plus 4 along 0->2->4->5).
dinic = Dinic(6)
dinic.add_edge(0, 1, 16)
dinic.add_edge(0, 2, 13)
dinic.add_edge(1, 2, 10)
dinic.add_edge(1, 3, 12)
dinic.add_edge(2, 4, 14)
dinic.add_edge(3, 5, 20)
dinic.add_edge(4, 5, 4)

print(f"Dinic 最大流: {dinic.max_flow(0, 5)}")

最小割问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def min_cut(n, edges, s, t):
    """Compute an s-t minimum cut.

    Args:
        n: number of nodes (``0..n-1``).
        edges: iterable of ``(u, v, cap)`` directed edges.
        s, t: source and sink.

    Returns:
        ``(max_flow, cut_edges)``. ``cut_edges`` lists the input edges going
        from the source side (nodes still reachable from ``s`` in the final
        residual graph) to the sink side.
    """
    network = MaxFlow(n)
    for u, v, cap in edges:
        network.add_edge(u, v, cap)

    flow_value = network.edmonds_karp(s, t)

    source_side = [False] * n
    source_side[s] = True
    pending = deque([s])
    while pending:
        u = pending.popleft()
        for v in network.graph[u]:
            if not source_side[v] and network.capacity[(u, v)] > 0:
                source_side[v] = True
                pending.append(v)

    cut_edges = [(u, v) for u, v, _ in edges if source_side[u] and not source_side[v]]
    return flow_value, cut_edges

二分图匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
class BipartiteMatching:
    """Maximum bipartite matching via augmenting paths (Kuhn's algorithm).

    Left vertices are ``0..n_left-1`` and right vertices ``0..n_right-1``;
    ``graph[u]`` lists the right vertices adjacent to left vertex ``u``.
    """

    def __init__(self, n_left, n_right):
        self.n_left, self.n_right = n_left, n_right
        self.graph = [[] for _ in range(n_left)]

    def add_edge(self, u, v):
        """Connect left vertex ``u`` to right vertex ``v``."""
        self.graph[u].append(v)

    def bpm_dfs(self, u, seen, match_r):
        """Try to match ``u``, re-routing existing partners if needed; return True on success."""
        for v in self.graph[u]:
            if seen[v]:
                continue
            seen[v] = True
            owner = match_r[v]
            if owner == -1 or self.bpm_dfs(owner, seen, match_r):
                match_r[v] = u
                return True
        return False

    def max_matching(self):
        """Return ``(size, match_r)``, where ``match_r[v]`` is v's left partner or -1."""
        match_r = [-1] * self.n_right
        result = sum(
            1
            for u in range(self.n_left)
            if self.bpm_dfs(u, [False] * self.n_right, match_r)
        )
        return result, match_r

bm = BipartiteMatching(4, 4)
bm.add_edge(0, 1)
bm.add_edge(0, 2)
bm.add_edge(1, 0)
bm.add_edge(2, 1)
bm.add_edge(3, 2)

matching, pairs = bm.max_matching()
print(f"最大匹配数: {matching}")

四十四、AC 自动机

AC 自动机是多模式字符串匹配的经典算法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from collections import deque

class AhoCorasick:
    """Aho-Corasick automaton for matching many patterns in one pass over a text.

    Node 0 is the root. ``trie[node]`` maps a character to a child node,
    ``fail[node]`` is the failure link, and ``output[node]`` lists the indices
    of every pattern that ends at this node (including via failure links once
    ``build_fail`` has run).
    """
    def __init__(self):
        self.trie = [{}]
        self.fail = [0]
        self.output = [[]]

    def add_pattern(self, pattern, index):
        """Insert ``pattern`` into the trie and tag its end node with ``index``."""
        node = 0
        for char in pattern:
            if char not in self.trie[node]:
                self.trie.append({})
                self.fail.append(0)
                self.output.append([])
                self.trie[node][char] = len(self.trie) - 1
            node = self.trie[node][char]
        self.output[node].append(index)

    def build_fail(self):
        """Compute failure links breadth-first; call once after adding all patterns.

        BFS order guarantees a node's failure target (which is shallower) is
        finished before the node itself, so its ``output`` can be inherited.
        """
        queue = deque()

        # depth-1 nodes keep fail = 0 (set in add_pattern)
        for char, node in self.trie[0].items():
            queue.append(node)

        while queue:
            curr = queue.popleft()

            for char, next_node in self.trie[curr].items():
                fail = self.fail[curr]

                # follow failure links until some state can extend by ``char``
                while fail and char not in self.trie[fail]:
                    fail = self.fail[fail]

                self.fail[next_node] = self.trie[fail].get(char, 0)
                self.output[next_node].extend(self.output[self.fail[next_node]])
                queue.append(next_node)

    def search(self, text):
        """Return ``(end_index, pattern_index)`` for every match in ``text``.

        ``end_index`` is the position of the match's last character.
        """
        node = 0
        results = []

        for i, char in enumerate(text):
            while node and char not in self.trie[node]:
                node = self.fail[node]

            node = self.trie[node].get(char, 0)

            for pattern_idx in self.output[node]:
                results.append((i, pattern_idx))

        return results

# Demo: in "ushers", "she" and "he" end at index 3 and "hers" ends at index 5.
ac = AhoCorasick()
patterns = ["he", "she", "his", "hers"]
for i, pattern in enumerate(patterns):
    ac.add_pattern(pattern, i)

ac.build_fail()
text = "ushers"
matches = ac.search(text)

print(f"文本: {text}")
for pos, idx in matches:
    print(f"位置 {pos}: 找到模式 '{patterns[idx]}'")

AC 自动机应用:敏感词过滤

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class SensitiveWordFilter:
    """Mask every occurrence of a set of sensitive words using an Aho-Corasick automaton."""

    def __init__(self, words):
        self.ac = AhoCorasick()
        # Copy so later changes to the caller's list cannot desync word indices.
        self.words = list(words)
        for i, word in enumerate(self.words):
            self.ac.add_pattern(word, i)
        self.ac.build_fail()

    def filter(self, text, replace_char='*'):
        """Return ``text`` with every character inside a matched word replaced by ``replace_char``.

        Overlapping matches are all masked; other characters are kept.
        """
        matches = self.ac.search(text)
        result = list(text)

        for end_pos, word_idx in matches:
            # search() reports the index of each match's last character
            start_pos = end_pos - len(self.words[word_idx]) + 1
            for i in range(start_pos, end_pos + 1):
                result[i] = replace_char

        return ''.join(result)

# Named ``word_filter`` (not ``filter``) so the built-in filter() stays usable.
word_filter = SensitiveWordFilter(["敏感", "违禁", "禁止"])
text = "这是一条包含敏感词汇和违禁内容的消息"
print(f"原文: {text}")
print(f"过滤后: {word_filter.filter(text)}")

四十五、后缀数组

后缀数组是处理字符串问题的强大工具。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def build_suffix_array(s):
    """Return the suffix array of ``s``: start indices of its suffixes in sorted order.

    Prefix doubling: after the pass with step ``k``, suffixes are ordered by
    their first 2k characters. Each pass sorts, so the total is O(n log^2 n).
    """
    n = len(s)
    k = 1
    # initial ranks = character codes; -1 stands for "past the end" (sorts first)
    rank = [ord(c) for c in s] + [-1]
    tmp = [0] * n
    sa = list(range(n))

    while k < n:
        # order by (rank of first k chars, rank of the next k chars)
        sa.sort(key=lambda x: (rank[x], rank[x + k] if x + k < n else -1))

        # re-rank: equal keys share a rank, a strictly larger key gets the next one
        tmp[sa[0]] = 0
        for i in range(1, n):
            prev, curr = sa[i - 1], sa[i]
            tmp[curr] = tmp[prev]
            if (rank[prev], rank[prev + k] if prev + k < n else -1) < \
                    (rank[curr], rank[curr + k] if curr + k < n else -1):
                tmp[curr] += 1

        rank = tmp[:]
        k *= 2

    return sa

def build_lcp(s, sa):
    """Return the LCP array: ``lcp[i]`` = common-prefix length of suffixes ``sa[i]`` and ``sa[i + 1]``.

    Kasai's algorithm, O(n). The result has ``len(s) - 1`` entries.
    """
    n = len(s)
    rank = [0] * n  # inverse of sa: rank[pos] = index of suffix pos in sa
    for i in range(n):
        rank[sa[i]] = i

    lcp = [0] * (n - 1)
    h = 0

    # visit suffixes in text order; h drops by at most 1 between steps
    for i in range(n):
        if rank[i] == 0:
            h = 0
            continue

        j = sa[rank[i] - 1]
        while i + h < n and j + h < n and s[i + h] == s[j + h]:
            h += 1

        lcp[rank[i] - 1] = h
        if h > 0:
            h -= 1

    return lcp

# Demo: for "banana", SA = [5, 3, 1, 0, 4, 2] and LCP = [1, 3, 0, 0, 2].
s = "banana"
sa = build_suffix_array(s)
lcp = build_lcp(s, sa)

print(f"字符串: {s}")
print("后缀数组:")
for i, pos in enumerate(sa):
    print(f" SA[{i}] = {pos}: {s[pos:]}")
print(f"LCP 数组: {lcp}")

后缀数组应用:最长重复子串

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def longest_repeated_substring(s):
    """Return the longest substring that occurs at least twice in ``s`` (occurrences may overlap).

    The answer is the largest LCP between neighbouring suffixes in the
    suffix array. Returns "" when nothing repeats.
    """
    if len(s) <= 1:
        return ""

    sa = build_suffix_array(s)
    lcp = build_lcp(s, sa)

    best_len, best_start = 0, 0
    for idx, common in enumerate(lcp):
        if common > best_len:
            best_len, best_start = common, sa[idx]

    return s[best_start:best_start + best_len]

print(f"最长重复子串: {longest_repeated_substring('banana')}")

后缀数组应用:不同子串个数

1
2
3
4
5
6
7
8
9
10
11
12
def count_distinct_substrings(s):
    """Return the number of distinct non-empty substrings of ``s``.

    All n(n+1)/2 substrings, minus the prefixes each suffix shares with its
    predecessor in the suffix array (the LCP values).
    """
    n = len(s)
    lcp = build_lcp(s, build_suffix_array(s))
    return n * (n + 1) // 2 - sum(lcp)

print(f"不同子串个数: {count_distinct_substrings('banana')}")

四十六、博弈论进阶

SG 函数深入

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def mex(s):
    """Return the minimum excluded value: the smallest non-negative integer not in ``s``."""
    mex_val = 0
    while mex_val in s:
        mex_val += 1
    return mex_val

def compute_sg(max_n, moves):
    """Return Sprague-Grundy values ``sg[0..max_n]`` for a subtraction game.

    From a heap of size ``i`` a player may remove any ``move`` in ``moves``
    with ``move <= i``.
    """
    sg = [0] * (max_n + 1)

    for i in range(1, max_n + 1):
        reachable = {sg[i - move] for move in moves if i >= move}
        sg[i] = mex(reachable)

    return sg

moves = [1, 3, 4]
sg = compute_sg(20, moves)
print(f"移动 {moves} 的 SG 值: {sg[:15]}")

def find_pattern(sg):
    """Return the smallest period p with ``sg[i] == sg[i - p]`` for all i >= p, or None.

    Periods up to ``len(sg) // 2`` inclusive are tried, so the pattern is
    seen repeating at least twice. The sequence is assumed to be periodic
    from index 0 (no pre-period).
    """
    n = len(sg)
    for period in range(1, n // 2 + 1):
        if all(sg[i] == sg[i - period] for i in range(period, n)):
            return period
    return None

print(f"周期: {find_pattern(sg)}")

Nim 游戏变体

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def nim_game_variants():
    """Placeholder for the Nim-variant examples of this section."""
    pass

def misere_nim(piles):
    """Decide a misère Nim position, where the player who takes the last object loses.

    Returns "先手必胜" (first player wins) or "先手必败" (first player loses).
    """
    xor_sum = 0
    for p in piles:
        xor_sum ^= p

    if all(p <= 1 for p in piles):
        # Only 0/1 piles, so every move is forced: the first player wins iff
        # the number of 1-piles is even (xor == 0). This includes the empty
        # board, where the opponent already took the last object.
        return "先手必胜" if xor_sum == 0 else "先手必败"

    # With at least one pile > 1, the outcome matches normal Nim.
    return "先手必胜" if xor_sum != 0 else "先手必败"

def wythoff_game(a, b):
    """Decide Wythoff's game for heaps ``(a, b)``.

    P-positions are ``(floor(k * phi), floor(k * phi) + k)``. Floating point
    is used, so very large k may be misjudged.
    """
    from math import sqrt

    if a > b:
        a, b = b, a

    k = b - a
    phi = (1 + sqrt(5)) / 2

    if a == int(k * phi):
        return "先手必败"
    return "先手必胜"

def grundy_game(n):
    """Return the Grundy value of a heap of ``n`` in Grundy's game.

    A move splits one heap into two non-empty heaps of different sizes.
    Values are memoised, so large ``n`` no longer takes exponential time.
    """
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def value(m):
        if m <= 2:
            return 0
        # splits (i, m - i) with 1 <= i < m - i
        reachable = {value(i) ^ value(m - i) for i in range(1, (m + 1) // 2)}
        g = 0
        while g in reachable:  # mex
            g += 1
        return g

    return value(n)

print(f"Misère Nim [1,1,1]: {misere_nim([1,1,1])}")
print(f"Wythoff 游戏 (3,5): {wythoff_game(3, 5)}")

组合游戏

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
class GameState:
    """Abstract game position; subclasses supply moves, a terminal test and a score."""

    def __init__(self, state):
        self.state = state

    def get_moves(self):
        """Return the positions reachable in one move (to be overridden)."""

    def is_terminal(self):
        """Return True when the game is over (to be overridden)."""

    def get_result(self):
        """Return the static evaluation of this position (to be overridden)."""

def alpha_beta_pruning(state, depth, alpha, beta, maximizing):
    """Minimax value of ``state`` searched ``depth`` plies deep, with alpha-beta cut-offs.

    ``state`` needs ``is_terminal()``, ``get_moves()`` (child states) and
    ``get_result()`` (evaluation used at leaves and at the depth limit).
    ``maximizing`` says whose turn it is at ``state``.
    """
    if depth == 0 or state.is_terminal():
        return state.get_result()

    best = float('-inf') if maximizing else float('inf')
    for child in state.get_moves():
        score = alpha_beta_pruning(child, depth - 1, alpha, beta, not maximizing)
        if maximizing:
            best = max(best, score)
            alpha = max(alpha, score)
        else:
            best = min(best, score)
            beta = min(beta, score)
        if beta <= alpha:
            break  # the opponent will never allow this line
    return best

四十七、线段树进阶

懒惰传播线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class LazySegmentTree:
    """Segment tree with lazy propagation: range add and range sum over half-open ``[l, r)``.

    Implicit binary tree over ``size`` leaves (the smallest power of two
    >= len(data)). Node 1 is the root and node i has children 2i and 2i+1.
    ``tree[node]`` is the sum of the node's range excluding ``lazy[node]``,
    a per-element add that has not been applied to the node yet.
    """
    def __init__(self, data):
        self.n = len(data)
        self.size = 1
        while self.size < self.n:
            self.size *= 2

        self.tree = [0] * (2 * self.size)
        self.lazy = [0] * (2 * self.size)

        # leaves live at size .. size + n - 1; padding leaves stay 0
        for i in range(self.n):
            self.tree[self.size + i] = data[i]

        for i in range(self.size - 1, 0, -1):
            self.tree[i] = self.tree[2 * i] + self.tree[2 * i + 1]

    def push(self, node, node_len):
        """Apply the node's pending add to ``tree[node]`` and hand it down to its children.

        ``node_len`` is the number of leaves under ``node``.
        """
        if self.lazy[node] != 0:
            self.tree[node] += self.lazy[node] * node_len

            if node < self.size:  # internal node: pass the add to both children
                self.lazy[2 * node] += self.lazy[node]
                self.lazy[2 * node + 1] += self.lazy[node]

            self.lazy[node] = 0

    def update_range(self, l, r, val, node=1, node_l=0, node_r=None):
        """Add ``val`` to every position in ``[l, r)``; ``node``/``node_l``/``node_r`` are recursion state."""
        if node_r is None:
            node_r = self.size

        # push first so tree[node] is exact before it is used or recombined
        self.push(node, node_r - node_l)

        if r <= node_l or node_r <= l:
            return

        if l <= node_l and node_r <= r:
            self.lazy[node] += val
            self.push(node, node_r - node_l)
            return

        mid = (node_l + node_r) // 2
        self.update_range(l, r, val, 2 * node, node_l, mid)
        self.update_range(l, r, val, 2 * node + 1, mid, node_r)

        self.tree[node] = self.tree[2 * node] + self.tree[2 * node + 1]

    def query_range(self, l, r, node=1, node_l=0, node_r=None):
        """Return the sum over ``[l, r)``; ``node``/``node_l``/``node_r`` are recursion state."""
        if node_r is None:
            node_r = self.size

        self.push(node, node_r - node_l)

        if r <= node_l or node_r <= l:
            return 0

        if l <= node_l and node_r <= r:
            return self.tree[node]

        mid = (node_l + node_r) // 2
        left_sum = self.query_range(l, r, 2 * node, node_l, mid)
        right_sum = self.query_range(l, r, 2 * node + 1, mid, node_r)

        return left_sum + right_sum

# Demo: prints 15, then 45 after adding 10 to indices 1..3, then 3 + 4 + 2 * 10 = 27.
lst = LazySegmentTree([1, 2, 3, 4, 5])
print(f"初始区间和 [0,5): {lst.query_range(0, 5)}")
lst.update_range(1, 4, 10)
print(f"更新后区间和 [0,5): {lst.query_range(0, 5)}")
print(f"区间和 [2,4): {lst.query_range(2, 4)}")

可持久化线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class PersistentSegmentTree:
    """Persistent segment tree answering static range k-th smallest queries.

    Values of ``arr`` are coordinate-compressed to ranks ``0..m-1``. Version
    ``i`` (``self.roots[i]``) counts, per rank, the first ``i`` elements of
    ``arr``, and shares every untouched node with version ``i - 1``.
    ``kth_smallest(l, r, k)`` works on ``arr[l:r]`` by subtracting version
    ``l`` from version ``r``. Nodes are dicts with keys ``'l'``, ``'r'``,
    ``'count'``, ``'left'`` and ``'right'``.
    """

    def __init__(self, arr):
        self.n = len(arr)
        self.roots = [None]
        self.nodes = []  # not used by this implementation; kept for compatibility

        sorted_arr = sorted(set(arr))
        self.compress = {v: i for i, v in enumerate(sorted_arr)}
        self.decompress = sorted_arr

        if not sorted_arr:
            # Nothing to index: building over the empty rank range [0, -1]
            # would recurse forever, so keep roots == [None].
            self.root0 = None
            return

        hi = len(sorted_arr) - 1
        self.root0 = self._build(0, hi)
        self.roots[0] = self.root0

        for val in arr:
            new_root = self._update(self.roots[-1], 0, hi, self.compress[val])
            self.roots.append(new_root)

    def _build(self, l, r):
        """Build the all-zero version over rank range ``[l, r]``."""
        node = {'l': l, 'r': r, 'count': 0, 'left': None, 'right': None}
        if l == r:
            return node
        mid = (l + r) // 2
        node['left'] = self._build(l, mid)
        node['right'] = self._build(mid + 1, r)
        return node

    def _update(self, prev, l, r, pos):
        """Return a new version of ``prev`` with rank ``pos`` counted once more (path copying)."""
        node = {'l': l, 'r': r, 'count': prev['count'] + 1, 'left': None, 'right': None}

        if l == r:
            return node

        mid = (l + r) // 2
        if pos <= mid:
            node['left'] = self._update(prev['left'], l, mid, pos)
            node['right'] = prev['right']
        else:
            node['left'] = prev['left']
            node['right'] = self._update(prev['right'], mid + 1, r, pos)

        return node

    def _query_kth(self, node_l, node_r, l, r, k):
        """Return the rank of the k-th smallest element counted by ``node_r`` minus ``node_l``."""
        if l == r:
            return l

        mid = (l + r) // 2
        left_count = node_r['left']['count'] - node_l['left']['count']

        if k <= left_count:
            return self._query_kth(node_l['left'], node_r['left'], l, mid, k)
        else:
            return self._query_kth(node_l['right'], node_r['right'], mid + 1, r, k - left_count)

    def kth_smallest(self, l, r, k):
        """Return the k-th smallest (1-based) value of ``arr[l:r]``.

        Raises:
            IndexError: if ``[l, r)`` is not a valid slice of the array or
                ``k`` is outside ``1..r - l``.
        """
        if not 0 <= l <= r <= self.n:
            raise IndexError(f"invalid range [{l}, {r}) for length {self.n}")
        if not 1 <= k <= r - l:
            raise IndexError(f"k={k} out of range for a slice of length {r - l}")
        return self.decompress[self._query_kth(self.roots[l], self.roots[r], 0, len(self.decompress) - 1, k)]

四十八、平衡树

Treap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import random

class TreapNode:
    """Treap node: BST on ``key``, max-heap on the random ``priority``; ``size`` = subtree node count."""
    def __init__(self, key):
        self.key = key
        self.priority = random.randint(1, 10**9)
        self.left = None
        self.right = None
        self.size = 1

class Treap:
    """Randomised balanced BST supporting insert, delete, k-th smallest and rank queries.

    Duplicate keys are allowed; on insert an equal key goes to the right subtree.
    """
    def __init__(self):
        self.root = None

    def _update_size(self, node):
        """Recompute ``node.size`` from its children (no-op for None)."""
        if node:
            node.size = 1 + (node.left.size if node.left else 0) + (node.right.size if node.right else 0)

    def _rotate_right(self, y):
        """Lift ``y.left`` above ``y``; return the new subtree root."""
        x = y.left
        y.left = x.right
        x.right = y
        self._update_size(y)
        self._update_size(x)
        return x

    def _rotate_left(self, x):
        """Lift ``x.right`` above ``x``; return the new subtree root."""
        y = x.right
        x.right = y.left
        y.left = x
        self._update_size(x)
        self._update_size(y)
        return y

    def insert(self, key):
        """Insert ``key``."""
        self.root = self._insert(self.root, key)

    def _insert(self, node, key):
        """Insert into the subtree at ``node``, rotating to restore the heap order; return its new root."""
        if not node:
            return TreapNode(key)

        if key < node.key:
            node.left = self._insert(node.left, key)
            if node.left.priority > node.priority:
                node = self._rotate_right(node)
        else:
            node.right = self._insert(node.right, key)
            if node.right.priority > node.priority:
                node = self._rotate_left(node)

        self._update_size(node)
        return node

    def delete(self, key):
        """Remove one occurrence of ``key``; does nothing if it is absent."""
        self.root = self._delete(self.root, key)

    def _delete(self, node, key):
        """Delete from the subtree at ``node`` and return its new root.

        A node with two children is rotated down (the higher-priority child
        comes up) until it has at most one child, then it is spliced out.
        """
        if not node:
            return None

        if key < node.key:
            node.left = self._delete(node.left, key)
        elif key > node.key:
            node.right = self._delete(node.right, key)
        else:
            if not node.left:
                return node.right
            if not node.right:
                return node.left

            if node.left.priority > node.right.priority:
                node = self._rotate_right(node)
                node.right = self._delete(node.right, key)
            else:
                node = self._rotate_left(node)
                node.left = self._delete(node.left, key)

        self._update_size(node)
        return node

    def kth(self, k):
        """Return the k-th smallest key (1-based).

        NOTE(review): there is no bounds check; k outside 1..size fails with
        AttributeError on a None node.
        """
        return self._kth(self.root, k)

    def _kth(self, node, k):
        left_size = node.left.size if node.left else 0

        if k <= left_size:
            return self._kth(node.left, k)
        elif k == left_size + 1:
            return node.key
        else:
            return self._kth(node.right, k - left_size - 1)

    def rank(self, key):
        """Return the 1-based position of ``key`` when present.

        If ``key`` is absent, returns the number of keys smaller than it.
        With duplicates, the position is that of the first equal key met on
        the search path.
        """
        return self._rank(self.root, key)

    def _rank(self, node, key):
        if not node:
            return 0

        if key < node.key:
            return self._rank(node.left, key)
        elif key == node.key:
            return (node.left.size if node.left else 0) + 1
        else:
            left_size = node.left.size if node.left else 0
            return left_size + 1 + self._rank(node.right, key)

# Demo: sorted keys are 1 3 4 5 6 7 8, so kth(3) is 4 and rank(5) is 4.
treap = Treap()
for x in [5, 3, 7, 1, 4, 6, 8]:
    treap.insert(x)

print(f"第 3 小的元素: {treap.kth(3)}")
print(f"5 的排名: {treap.rank(5)}")

Splay 树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class SplayNode:
    """Splay-tree node with a parent pointer and subtree ``size``."""
    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None
        self.parent = None
        self.size = 1

class SplayTree:
    """Self-adjusting BST: every insert (or repeated key) is splayed to the root."""
    def __init__(self):
        self.root = None

    def _update_size(self, node):
        """Recompute ``node.size`` from its children (no-op for None)."""
        if node:
            node.size = 1
            if node.left:
                node.size += node.left.size
            if node.right:
                node.size += node.right.size

    def _rotate(self, x):
        """Rotate ``x`` above its parent, fixing parent links, ``self.root`` and sizes."""
        p = x.parent
        if not p:
            return

        if p.left == x:
            # right rotation: x's right subtree becomes p's left subtree
            p.left = x.right
            if x.right:
                x.right.parent = p
            x.right = p
        else:
            # left rotation: x's left subtree becomes p's right subtree
            p.right = x.left
            if x.left:
                x.left.parent = p
            x.left = p

        x.parent = p.parent
        p.parent = x

        # hook x into p's old position (or make it the root)
        if x.parent:
            if x.parent.left == p:
                x.parent.left = x
            else:
                x.parent.right = x
        else:
            self.root = x

        # p is now x's child, so update it first
        self._update_size(p)
        self._update_size(x)

    def _splay(self, x):
        """Move ``x`` to the root with zig / zig-zig / zig-zag steps."""
        while x.parent:
            p = x.parent
            g = p.parent

            if not g:
                self._rotate(x)                  # zig
            elif (g.left == p) == (p.left == x):
                self._rotate(p)                  # zig-zig: rotate parent first
                self._rotate(x)
            else:
                self._rotate(x)                  # zig-zag
                self._rotate(x)

    def insert(self, key):
        """Insert ``key`` as a new leaf and splay it to the root.

        If ``key`` is already present, no node is added; the existing node is splayed instead.
        """
        if not self.root:
            self.root = SplayNode(key)
            return

        node = self.root
        while True:
            if key < node.key:
                if not node.left:
                    node.left = SplayNode(key)
                    node.left.parent = node
                    self._splay(node.left)
                    return
                node = node.left
            elif key > node.key:
                if not node.right:
                    node.right = SplayNode(key)
                    node.right.parent = node
                    self._splay(node.right)
                    return
                node = node.right
            else:
                self._splay(node)
                return

四十九、高级搜索技术

迭代加深搜索

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def iterative_deepening_dfs(start, goal, max_depth, get_neighbors, is_goal):
    """Run depth-limited DFS with limits 0, 1, ..., ``max_depth``.

    Returns the first path found (a list of states from ``start`` to a goal
    state) or None. Because limits grow one at a time, the path found uses
    the fewest moves.
    """
    for limit in range(max_depth + 1):
        found = dfs_with_depth(start, goal, limit, set(), get_neighbors, is_goal)
        if found is not None:
            return found
    return None


def dfs_with_depth(node, goal, depth, visited, get_neighbors, is_goal):
    """Depth-limited DFS from ``node``; returns a path to the goal or None.

    ``visited`` holds only the states on the current path, which prevents
    cycles while still letting other branches revisit a state.
    """
    if is_goal(node, goal):
        return [node]
    if depth == 0:
        return None

    visited.add(node)
    for nxt in get_neighbors(node):
        if nxt in visited:
            continue
        tail = dfs_with_depth(nxt, goal, depth - 1, visited, get_neighbors, is_goal)
        if tail is not None:
            return [node] + tail
    visited.remove(node)
    return None


def get_neighbors_8puzzle(state):
    """Return the boards reachable by sliding one tile into the blank (0).

    Boards are row-major 9-tuples. Candidates are produced in blank-move
    order up, down, left, right.
    """
    blank = state.index(0)
    r, c = divmod(blank, 3)
    out = []
    for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        nr, nc = r + dr, c + dc
        if not (0 <= nr < 3 and 0 <= nc < 3):
            continue
        target = nr * 3 + nc
        board = list(state)
        board[blank], board[target] = board[target], board[blank]
        out.append(tuple(board))
    return out


def is_goal_8puzzle(state, goal):
    """Return True when ``state`` equals ``goal``."""
    return state == goal

# 8-puzzle demo: boards are row-major tuples and 0 marks the blank square.
start = (1, 2, 3,
         4, 0, 5,
         6, 7, 8)
goal = (1, 2, 3,
        4, 5, 6,
        7, 8, 0)

result = iterative_deepening_dfs(
    start, goal, 30, get_neighbors_8puzzle, is_goal_8puzzle
)
if result:
    print(f"找到解,步数: {len(result) - 1}")

IDA* 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def ida_star(start, goal, heuristic, get_neighbors, is_goal):
    """Iterative-deepening A*: repeat a bounded DFS, raising the f = g + h bound each round.

    Every move costs 1. Returns the path found as a list of states, or None
    when the search space is exhausted.
    """
    bound = heuristic(start, goal)
    trail = [start]

    while True:
        outcome = ida_search(trail, 0, bound, goal, heuristic, get_neighbors, is_goal)
        if isinstance(outcome, list):
            return outcome
        if outcome == float('inf'):
            return None
        bound = outcome  # smallest f value that exceeded the previous bound


def ida_search(path, g, threshold, goal, heuristic, get_neighbors, is_goal):
    """Bounded DFS from ``path[-1]`` with cost so far ``g``.

    Returns a copy of the path on success; otherwise returns the smallest f
    value that exceeded ``threshold`` (inf if there was none). ``path`` is
    extended and restored in place.
    """
    current = path[-1]
    estimate = g + heuristic(current, goal)

    if estimate > threshold:
        return estimate
    if is_goal(current, goal):
        return list(path)

    next_bound = float('inf')
    for nxt in get_neighbors(current):
        if nxt in path:
            continue
        path.append(nxt)
        outcome = ida_search(path, g + 1, threshold, goal, heuristic, get_neighbors, is_goal)
        if isinstance(outcome, list):
            return outcome
        next_bound = min(next_bound, outcome)
        path.pop()

    return next_bound


def manhattan_distance(state, goal):
    """Sum over tiles 1..8 of each tile's Manhattan distance to its place in ``goal`` (3x3 board)."""
    total = 0
    for idx in range(9):
        tile = state[idx]
        if tile == 0:
            continue  # the blank does not count
        r1, c1 = divmod(idx, 3)
        r2, c2 = divmod(goal.index(tile), 3)
        total += abs(r1 - r2) + abs(c1 - c2)
    return total

# IDA* on the same 8-puzzle instance, guided by the Manhattan-distance heuristic.
start = (1, 2, 3,
         4, 0, 5,
         6, 7, 8)
goal = (1, 2, 3,
        4, 5, 6,
        7, 8, 0)

result = ida_star(
    start, goal, manhattan_distance, get_neighbors_8puzzle, is_goal_8puzzle
)
if result:
    print(f"IDA* 找到解,步数: {len(result) - 1}")

双向 BFS

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from collections import deque

def bidirectional_bfs(start, goal, get_neighbors):
    """Search from both ends at once, always growing the smaller frontier by one level.

    ``get_neighbors`` is assumed to describe an undirected graph. Returns a
    path ``[start, ..., goal]`` or None if the two searches never meet.
    """
    if start == goal:
        return [start]

    fwd_q, bwd_q = deque([start]), deque([goal])
    fwd_paths, bwd_paths = {start: [start]}, {goal: [goal]}

    while fwd_q and bwd_q:
        grow_forward = len(fwd_q) <= len(bwd_q)
        if grow_forward:
            joined = _bfs_step(fwd_q, fwd_paths, bwd_paths, get_neighbors, True)
        else:
            joined = _bfs_step(bwd_q, bwd_paths, fwd_paths, get_neighbors, False)
        if joined:
            return joined

    return None


def _bfs_step(queue, visited, other_visited, get_neighbors, is_forward):
    """Expand one whole BFS level of ``queue``; return the joined path once the frontiers meet.

    ``visited`` maps each node to its path from this side's root. The first
    node also found by the other side yields the start-to-goal path.
    """
    for _ in range(len(queue)):
        node = queue.popleft()
        for nxt in get_neighbors(node):
            if nxt in visited:
                continue
            visited[nxt] = visited[node] + [nxt]
            queue.append(nxt)
            if nxt in other_visited:
                if is_forward:
                    head, tail = visited[nxt], other_visited[nxt]
                else:
                    head, tail = other_visited[nxt], visited[nxt]
                # tail runs goal..nxt; drop nxt and reverse it to continue towards goal
                return head + tail[-2::-1]
    return None

# Reuses the 8-puzzle `start` / `goal` boards bound by the earlier demos in this file.
result = bidirectional_bfs(
    start, goal, get_neighbors_8puzzle
)
if result:
    print(f"双向 BFS 找到解,步数: {len(result) - 1}")

五十、高级动态规划

插头 DP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def plug_dp(grid):
    """Count Hamiltonian circuits that visit every free cell of ``grid`` exactly once.

    ``grid[i][j] == 0`` is a free cell; anything else is an obstacle.
    Classic plug DP in bracket form, processed cell by cell. A state packs
    m + 1 plugs of 2 bits each (position ``p`` at bits ``2p..2p+1``). For
    cell (i, j), position ``j`` is the plug entering from the left and
    position ``j + 1`` the plug entering from above; after the cell they
    hold its down plug and its right plug. Plug values: 0 = none,
    1 = left end "(" of a path, 2 = right end ")" of a path.

    Returns 0 when there is no free cell.
    """
    n = len(grid)
    m = len(grid[0]) if n else 0

    # A circuit may only be closed at the last free cell in row-major order.
    last = None
    for i in range(n):
        for j in range(m):
            if grid[i][j] == 0:
                last = (i, j)
    if last is None:
        return 0

    def get(state, pos):
        return (state >> (2 * pos)) & 3

    def put(state, pos, val):
        return (state & ~(3 << (2 * pos))) | (val << (2 * pos))

    total = 0
    dp = {0: 1}
    for i in range(n):
        # New row: the previous row's final right plug must be empty, and
        # shifting by one position turns the down plugs into up plugs.
        dp = {state << 2: cnt for state, cnt in dp.items() if get(state, m) == 0}
        for j in range(m):
            can_down = i + 1 < n and grid[i + 1][j] == 0
            can_right = j + 1 < m and grid[i][j + 1] == 0
            new_dp = {}
            for state, cnt in dp.items():
                left, up = get(state, j), get(state, j + 1)
                base = put(put(state, j, 0), j + 1, 0)
                successors = []

                if grid[i][j] != 0:
                    # obstacle: no plug may enter it
                    if left == 0 and up == 0:
                        successors.append(state)
                elif left == 0 and up == 0:
                    # open a new path segment leaving down and right
                    if can_down and can_right:
                        successors.append(put(put(base, j, 1), j + 1, 2))
                elif left == 0 or up == 0:
                    # extend the single incoming path end down or right
                    plug = left or up
                    if can_down:
                        successors.append(put(base, j, plug))
                    if can_right:
                        successors.append(put(base, j + 1, plug))
                elif left == 1 and up == 1:
                    # join two "(" ends: the ")" matching the upper one becomes "("
                    depth = 0
                    for p in range(j + 1, m + 1):
                        v = get(state, p)
                        if v == 1:
                            depth += 1
                        elif v == 2:
                            depth -= 1
                        if depth == 0:
                            base = put(base, p, 1)
                            break
                    successors.append(base)
                elif left == 2 and up == 2:
                    # join two ")" ends: the "(" matching the left one becomes ")"
                    depth = 0
                    for p in range(j, -1, -1):
                        v = get(state, p)
                        if v == 2:
                            depth += 1
                        elif v == 1:
                            depth -= 1
                        if depth == 0:
                            base = put(base, p, 2)
                            break
                    successors.append(base)
                elif left == 2 and up == 1:
                    # ")" meets "(": two paths merge and their far ends now pair up
                    successors.append(base)
                else:
                    # "(" meets ")": both ends of one path meet and the circuit
                    # closes; valid only at the last free cell with no other plugs.
                    if (i, j) == last and base == 0:
                        total += cnt

                for key in successors:
                    new_dp[key] = new_dp.get(key, 0) + cnt
            dp = new_dp

    return total

轮廓线 DP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def contour_dp(grid):
    """Count the domino (1x2 / 2x1) tilings of the free cells of ``grid``.

    ``grid[i][j] == 1`` marks an obstacle; every other cell must be covered
    by exactly one domino. Broken-profile (contour-line) DP, cell by cell:
    once cell (i, j) has been handled, bit ``k`` of ``mask`` means
      * for ``k <= j``: cell (i + 1, k) is already covered by a vertical
        domino from row i;
      * for ``k > j``: cell (i, k) is already covered (by a vertical domino
        from row i - 1 or a horizontal one).

    Returns 1 for an empty grid (the empty tiling).
    """
    n = len(grid)
    m = len(grid[0]) if n else 0

    dp = {0: 1}
    for i in range(n):
        for j in range(m):
            bit = 1 << j
            right_bit = bit << 1
            new_dp = {}
            for mask, count in dp.items():
                if grid[i][j] == 1:
                    # obstacle: nothing may stick into it from above
                    if not mask & bit:
                        new_dp[mask] = new_dp.get(mask, 0) + count
                    continue
                if mask & bit:
                    # already covered: consume the mark
                    key = mask & ~bit
                    new_dp[key] = new_dp.get(key, 0) + count
                    continue
                # vertical domino covering (i, j) and (i + 1, j)
                if i + 1 < n and grid[i + 1][j] != 1:
                    key = mask | bit
                    new_dp[key] = new_dp.get(key, 0) + count
                # horizontal domino covering (i, j) and (i, j + 1)
                if j + 1 < m and grid[i][j + 1] != 1 and not mask & right_bit:
                    key = mask | right_bit
                    new_dp[key] = new_dp.get(key, 0) + count
            dp = new_dp

    return dp.get(0, 0)

数位 DP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def digit_dp(n, condition, update_state=None, initial_state=0):
    """Count integers x in [0, n] whose digit-derived state satisfies ``condition``.

    Digit-DP template. The digits of each x, left-padded with zeros to the
    length of ``str(n)``, are folded left to right with
    ``update_state(state, digit)`` starting from ``initial_state``. x is
    counted when ``condition(final_state)`` is true.

    Args:
        n: non-negative inclusive upper bound.
        condition: predicate on the final state.
        update_state: transition ``(state, digit) -> state``. When omitted,
            the state is the running digit sum. (Previously the template
            referenced undefined ``update_state``/``initial_state`` names.)
        initial_state: starting state (default ``0``).

    States must be hashable because results are memoised with ``lru_cache``.
    Leading zeros pass through ``update_state`` like any other digit.
    """
    from functools import lru_cache

    if update_state is None:
        step = lambda state, digit: state + digit
    else:
        step = update_state

    s = str(n)
    length = len(s)

    @lru_cache(maxsize=None)
    def dfs(pos, tight, state):
        # tight: every digit so far equals n's prefix, so the next digit is capped
        if pos == length:
            return 1 if condition(state) else 0

        limit = int(s[pos]) if tight else 9
        result = 0

        for digit in range(limit + 1):
            new_tight = tight and digit == limit
            result += dfs(pos + 1, new_tight, step(state, digit))

        return result

    return dfs(0, True, initial_state)

def count_without_49(n):
    """Count integers in 0..n (inclusive) whose decimal form does not contain "49".

    Note that 0 itself is part of the count.  Negative n yields 0.
    """
    from functools import lru_cache

    if n < 0:
        return 0

    s = str(n)
    length = len(s)

    @lru_cache(maxsize=None)
    def dfs(pos, tight, prev_four):
        # pos: index of the digit being chosen; tight: prefix equals n's prefix;
        # prev_four: the previously chosen digit was 4.
        if pos == length:
            return 1

        limit = int(s[pos]) if tight else 9
        result = 0
        for digit in range(limit + 1):
            if prev_four and digit == 9:
                continue  # would complete the forbidden "49"
            result += dfs(pos + 1, tight and digit == limit, digit == 4)
        return result

    return dfs(0, True, False)

# count_without_49(n) counts 0..n inclusive, so subtract 1 to exclude 0.
print(f"1 到 1000 中不含 49 的数: {count_without_49(1000) - 1}")

五十一、更多真题实战

2020 年真题:跑步锻炼

题目:小蓝每天都锻炼身体。有些日子,他会跑步,并且记录跑步的里程。2020 年 1 月 1 日是星期三,小蓝开始跑步。从这天开始,每周的周一或月初(1 日)他会跑 2 千米,其他日子跑 1 千米。请问到 2020 年 12 月 31 日,小蓝总共跑了多少千米?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from datetime import datetime, timedelta

def running_distance():
    """Kilometres run from 2020-01-01 through 2020-12-31 (inclusive).

    Mondays and the first day of each month count 2 km; every other day 1 km.
    """
    first = datetime(2020, 1, 1)
    span = (datetime(2020, 12, 31) - first).days + 1
    days = (first + timedelta(days=k) for k in range(span))
    return sum(2 if d.weekday() == 0 or d.day == 1 else 1 for d in days)


print(f"总跑步距离: {running_distance()} 千米")

2020 年真题:蛇形填数

题目:如下图所示,小明用从 1 开始的正整数“蛇形”填充无限大的矩阵(矩阵左上角的前几行依次为:1 2 6 7 15 …;3 5 8 14 …;4 9 13 …;10 12 …;11 …)。容易看出矩阵第 2 行第 2 列的数是 5。请你计算矩阵中第 20 行第 20 列的数是多少?

1
2
3
4
5
6
7
8
9
10
11
12
def snake_matrix_value(row, col):
    """Value at 1-based (row, col) of the snake-filled infinite matrix.

    Anti-diagonal k = row + col - 1 holds k*(k-1)/2 + 1 .. k*(k+1)/2.  Odd
    diagonals run bottom-left to top-right (position = col); even ones run
    top-right to bottom-left (position = row).
    """
    k = row + col - 1
    before = k * (k - 1) // 2
    offset = col if k % 2 else row
    return before + offset


print(f"第 20 行第 20 列的数: {snake_matrix_value(20, 20)}")

2019 年真题:等差数列

题目:数学老师给小明出了一道等差数列求和的题目。但是粗心的小明忘记了一部分的数列,只记得其中 N 个整数。现在给出这 N 个整数,小明想知道包含这 N 个整数的最短的等差数列有几项?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from math import gcd
from functools import reduce

def find_gcd_of_list(numbers):
    """Return the gcd of all values in ``numbers`` (must be non-empty)."""
    return reduce(gcd, numbers)


def shortest_arithmetic_sequence(nums):
    """Number of terms in the shortest arithmetic sequence containing all of ``nums``.

    The common difference is the gcd of the gaps between consecutive sorted
    values; when all values are equal the sequence is constant and needs
    len(nums) terms.  The caller's list is not modified.
    """
    values = sorted(nums)

    if len(values) == 1:
        return 1

    diffs = [right - left for left, right in zip(values, values[1:])]

    if all(d == 0 for d in diffs):
        return len(values)

    step = find_gcd_of_list(diffs)
    return (values[-1] - values[0]) // step + 1


nums = [2, 6, 4, 10, 20, 30]
print(f"最短等差数列项数: {shortest_arithmetic_sequence(nums)}")

2019 年真题:数的分解

题目:把 2019 分解成 3 个各不相同的正整数之和,并且要求每个正整数都不包含数字 2 和 4,一共有多少种不同的分解方法?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def has_digit(n, digit):
    """Return True if the decimal representation of ``n`` contains ``digit``."""
    return str(digit) in str(n)


def count_decompositions(target):
    """Count ways to write ``target`` as a + b + c with 0 < a < b < c and no
    part containing the digit 2 or 4.
    """
    count = 0

    for a in range(1, target):
        if has_digit(a, 2) or has_digit(a, 4):
            continue

        for b in range(a + 1, target):
            c = target - a - b
            if c <= b:
                # c only shrinks as b grows, so no larger b can satisfy b < c.
                break
            if has_digit(b, 2) or has_digit(b, 4):
                continue
            if has_digit(c, 2) or has_digit(c, 4):
                continue
            count += 1

    return count

# Print the number of ways to split 2019 into three distinct positive parts
# that contain neither the digit 2 nor the digit 4.
print(f"分解方法数: {count_decompositions(2019)}")

2018 年真题:递增三元组

题目:给定三个整数数组 A, B, C,请统计有多少个三元组 (i, j, k) 满足:A[i] < B[j] < C[k]。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from bisect import bisect_left, bisect_right

def count_increasing_triplets(A, B, C):
    """Count index triples (i, j, k) with A[i] < B[j] < C[k].

    For each b in B, the number of strictly smaller A values and strictly
    larger C values is found by binary search on sorted copies, giving
    O((|A| + |C|) log + |B| log) time.  The input lists are not modified.
    """
    sorted_a = sorted(A)
    sorted_c = sorted(C)

    count = 0
    for b in B:
        less_a = bisect_left(sorted_a, b)
        greater_c = len(sorted_c) - bisect_right(sorted_c, b)
        count += less_a * greater_c

    return count


A = [1, 1, 1]
B = [2, 2, 2]
C = [3, 3, 3]

print(f"递增三元组数量: {count_increasing_triplets(A, B, C)}")

五十二、代码优化技巧

常数优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import sys
from functools import lru_cache

# Shadow the built-in input() with sys.stdin.readline for the rest of this
# module (faster for large inputs).  Unlike input(), readline keeps the
# trailing newline and returns '' at EOF instead of raising EOFError, so
# strip() lines that are used as strings.
input = sys.stdin.readline

def fast_io():
    """Read all of stdin in one call and return an iterator over its whitespace-separated tokens."""
    return iter(sys.stdin.read().split())

@lru_cache(maxsize=None)
def fibonacci(n):
    """Return the n-th Fibonacci number (F(0) = 0, F(1) = 1), memoized."""
    return n if n <= 1 else fibonacci(n - 1) + fibonacci(n - 2)

def optimize_loops():
    """Sum 0..999_999 with the built-in ``sum`` (the loop runs in C).

    The slow equivalent is a manual ``for x in arr: result += x`` loop, which
    performs one interpreted iteration per element for the same result.
    """
    arr = list(range(1000000))
    return sum(arr)

def use_local_variables():
    """Demonstrate caching a bound method in a local name before a hot loop.

    Starts from ``list(range(10000))`` and appends 0..9999 again through the
    cached ``append``, so the returned list has 20000 elements.
    """
    values = list(range(10000))
    push = values.append  # one attribute lookup instead of one per iteration
    for k in range(len(values)):
        push(k)
    return values

内存优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def memory_optimization():
    """Return a lazy generator over 0 .. 10**7 - 1 instead of a materialized list.

    A list comprehension such as ``[i for i in range(10**7)]`` would allocate
    all ten million int objects up front; the generator yields them one at a
    time.  ``chunked_processing`` illustrates walking an existing sequence in
    fixed-size slices.
    """
    def chunked_processing(data, chunk_size=10000):
        # Yield consecutive slices of at most chunk_size items (illustration).
        for i in range(0, len(data), chunk_size):
            yield data[i:i + chunk_size]

    large_gen = (i for i in range(10**7))
    return large_gen

def use_slots():
    """Build and return a point whose class uses ``__slots__`` (no per-instance ``__dict__``)."""
    class Point:
        __slots__ = ['x', 'y']

        def __init__(self, x, y):
            self.x, self.y = x, y

    return Point(1, 2)

输入输出优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import sys

class FastIO:
    """Thin wrapper reading from binary stdin and writing to text stdout."""

    def __init__(self):
        self.buffer = sys.stdin.buffer
        self.output = sys.stdout

    def read_int(self):
        """Read one line and parse it as a single int."""
        line = self.buffer.readline()
        return int(line)

    def read_ints(self):
        """Read one line of whitespace-separated ints."""
        return [int(tok) for tok in self.buffer.readline().split()]

    def read_str(self):
        """Read one line, decode it and strip surrounding whitespace."""
        raw = self.buffer.readline()
        return raw.decode().strip()

    def write(self, s):
        """Write str(s) without a trailing newline."""
        self.output.write(str(s))

    def writeln(self, s):
        """Write str(s) followed by a newline."""
        self.output.write(str(s) + '\n')

def batch_output(results):
    """Print every item of ``results`` on its own line using a single print call."""
    text = '\n'.join(str(r) for r in results)
    print(text)

五十三、常见错误与陷阱

整数溢出与精度问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def common_errors():
    """Section placeholder for common pitfalls; intentionally does nothing."""

def precision_issues():
    """Show binary floating-point rounding next to exact ``Decimal`` arithmetic.

    Prints the float sum 0.1 + 0.2 (which is not exactly 0.3), the result of
    comparing it with 0.3, and the Decimal sum; returns ``Decimal('0.3')``.
    """
    from decimal import Decimal

    float_sum = 0.1 + 0.2
    print(f"0.1 + 0.2 = {float_sum}")
    print(f"0.1 + 0.2 == 0.3: {float_sum == 0.3}")

    exact_sum = Decimal('0.1') + Decimal('0.2')
    print(f"Decimal: {exact_sum}")
    return exact_sum

def division_traps():
    """Print how // and % round toward negative infinity while int() truncates toward zero."""
    for label, value in (
        ("-7 // 3", -7 // 3),
        ("-7 % 3", -7 % 3),
        ("int(-3.7)", int(-3.7)),
        ("int(3.7)", int(3.7)),
    ):
        print(f"{label} = {value}")

def list_copy_traps():
    """Contrast a shallow slice copy (inner lists shared) with copy.deepcopy."""
    import copy

    original = [[1, 2], [3, 4]]
    shallow = original[:]
    shallow[0][0] = 999  # the inner list is shared, so `original` changes too
    print(f"原列表被修改: {original}")

    deep = copy.deepcopy(original)
    deep[0][0] = 111  # fully independent copy; `original` is untouched
    print(f"深拷贝不影响原列表: {original}")

边界条件处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
def handle_edge_cases():
    """Section placeholder for boundary-condition helpers; intentionally does nothing."""

def safe_divide(a, b):
    """Return a / b; when b == 0 return +inf (a >= 0) or -inf (otherwise) instead of raising."""
    if b != 0:
        return a / b
    return float('inf') if a >= 0 else float('-inf')

def safe_mod(a, b):
    """Return a % b, raising ValueError (rather than ZeroDivisionError) when b == 0."""
    if b == 0:
        raise ValueError("除数不能为 0")
    remainder = a % b
    return remainder

def array_access(arr, idx):
    """Return arr[idx] when 0 <= idx < len(arr); otherwise None (negative indices rejected)."""
    return arr[idx] if 0 <= idx < len(arr) else None

def matrix_access(matrix, row, col):
    """Return matrix[row][col] when both indices are in range, otherwise None.

    The column bound is checked against the length of the requested row, so
    jagged (non-rectangular) matrices are handled as well as rectangular ones.
    """
    if 0 <= row < len(matrix) and 0 <= col < len(matrix[row]):
        return matrix[row][col]
    return None

常见逻辑错误

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def logic_errors():
    """Section placeholder for common logic mistakes; intentionally does nothing."""

def correct_comparison():
    """Report whether any item has val == 3 (with this sample data: False)."""
    items = [{'val': 1}, {'val': 2}]
    return any(item['val'] == 3 for item in items)

def avoid_off_by_one():
    """Illustrate the three common index ranges over a length-n sequence."""
    n = 10
    arr = list(range(n))  # sample data the ranges below would index into

    for _ in range(n):        # every index: 0 .. n-1
        pass

    for _ in range(n - 1):    # all but the last: 0 .. n-2 (pairs arr[i], arr[i+1])
        pass

    for _ in range(1, n):     # all but the first: 1 .. n-1 (pairs arr[i-1], arr[i])
        pass

def correct_recursion():
    """Compute Fibonacci(50) with a memoized recursive helper."""
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def fib(k):
        if k > 1:
            return fib(k - 1) + fib(k - 2)
        return k

    return fib(50)

五十四、比赛策略与技巧

时间分配策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
蓝桥杯比赛时间分配建议(4 小时):

1. 前 30 分钟:
- 快速浏览所有题目
- 标记简单题和难题
- 确定答题顺序

2. 第 30-90 分钟:
- 完成所有结果填空题
- 确保简单题不丢分

3. 第 90-180 分钟:
- 攻克编程大题
- 先完成有把握的题目

4. 最后 60 分钟:
- 检查已完成的代码
- 尝试难题的部分分
- 确保代码能正常运行(Python 不需要编译,但提交前要在本地完整运行一遍,排除语法和运行时错误)

部分分策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def partial_score_strategy():
    """Section placeholder for partial-score techniques; intentionally does nothing."""

def solve_with_brute_force():
    """O(n^2) baseline: read n and n ints, print the sum of a[i] * a[j] over all i < j."""
    n = int(input())
    arr = list(map(int, input().split()))

    acc = 0
    for i in range(n):
        for j in range(i + 1, n):
            acc += arr[i] * arr[j]

    print(acc)

def solve_with_optimized():
    """O(n) version of the pairwise-product sum.

    sum over i<j of a[i]*a[j] = (S^2 - sum a[i]^2) / 2 with S = sum(a); this is
    computed as sum(x * (S - x)) // 2, whose numerator is always even.
    """
    n = int(input())
    arr = list(map(int, input().split()))
    s = sum(arr)
    print(sum(x * (s - x) for x in arr) // 2)

def handle_small_cases():
    """Skeleton for choosing an algorithm by input size to collect partial scores."""
    n = int(input())

    if n > 1000:
        pass  # large inputs: need the fully optimized algorithm
    elif n > 20:
        pass  # medium inputs (21..1000): a polynomial algorithm such as O(n^2)
    else:
        pass  # tiny inputs (<= 20): exhaustive search is affordable

调试技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def debug_techniques():
    """Section placeholder for debugging helpers; intentionally does nothing."""

def debug_print(*args, **kwargs):
    """print() to stderr with a "[DEBUG]" prefix so judged stdout stays clean."""
    import sys
    print("[DEBUG]", *args, file=sys.stderr, **kwargs)

def assert_with_message(condition, message):
    """Print and raise AssertionError(message) when condition is falsy.

    Unlike the ``assert`` statement, this check is not stripped by ``python -O``.
    """
    if condition:
        return
    print(f"断言失败: {message}")
    raise AssertionError(message)

def test_with_examples():
    """Check sum() against a few hand-written cases and print a success message."""
    cases = (
        (5, [1, 2, 3, 4, 5], 15),
        (3, [1, 1, 1], 3),
        (0, [], 0),
    )
    for n, arr, expected in cases:
        assert sum(arr) == expected, f"测试失败: n={n}, arr={arr}"
    print("所有测试通过!")

def visualize_data(data):
    """Print a short summary: list length/head/tail, or dict size and first 10 items."""
    if isinstance(data, list):
        print(f"列表长度: {len(data)}")
        print(f"前 10 个元素: {data[:10]}")
        print(f"后 10 个元素: {data[-10:]}")
        return
    if isinstance(data, dict):
        print(f"字典大小: {len(data)}")
        for key, value in list(data.items())[:10]:
            print(f" {key}: {value}")

五十五、综合实战案例

案例:迷宫最短路径

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from collections import deque

def maze_shortest_path(maze, start, end):
    """Shortest 4-directional path length in a grid maze, by BFS.

    Args:
        maze: rectangular grid (list of strings or list of char lists); '#' is a wall.
        start, end: (row, col) pairs; any 2-element sequence is accepted.

    Returns:
        Number of moves from start to end, or -1 if end is unreachable.
    """
    rows, cols = len(maze), len(maze[0])
    target = tuple(end)  # so an `end` given as a list such as [4, 4] still matches
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    queue = deque([(start[0], start[1], 0)])
    visited = [[False] * cols for _ in range(rows)]
    visited[start[0]][start[1]] = True

    while queue:
        x, y, dist = queue.popleft()

        if (x, y) == target:
            return dist

        for dx, dy in directions:
            nx, ny = x + dx, y + dy
            if 0 <= nx < rows and 0 <= ny < cols:
                if not visited[nx][ny] and maze[nx][ny] != '#':
                    # mark on enqueue so each cell enters the queue once
                    visited[nx][ny] = True
                    queue.append((nx, ny, dist + 1))

    return -1


maze = [
    "....#",
    ".#...",
    ".#.#.",
    "...#.",
    "#...."
]

maze_grid = [list(row) for row in maze]
result = maze_shortest_path(maze_grid, (0, 0), (4, 4))
print(f"最短路径: {result}")

案例:表达式计算器

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def advanced_calculator(expression):
    """Evaluate an infix expression with + - * / and parentheses (two-stack method).

    Operands are non-negative integer or decimal literals ("12", "3.5");
    integers stay ``int`` and decimals become ``float``.  ``/`` is true
    division.  Operators are left-associative: before pushing an operator,
    every stacked operator of higher or equal precedence is applied.
    Whitespace is ignored; unary minus is not supported.
    """
    import re

    # A decimal literal must be one token, otherwise "3.5" would become 3 and 5.
    tokens = re.findall(r'\d+(?:\.\d+)?|[+\-*/()]', expression)

    def precedence(op):
        if op in '+-':
            return 1
        if op in '*/':
            return 2
        return 0  # '(' never triggers a reduction

    def apply_op(a, b, op):
        if op == '+': return a + b
        if op == '-': return a - b
        if op == '*': return a * b
        if op == '/': return a / b

    values = []
    ops = []

    for token in tokens:
        if token[0].isdigit():
            values.append(float(token) if '.' in token else int(token))
        elif token == '(':
            ops.append(token)
        elif token == ')':
            while ops and ops[-1] != '(':
                b, a = values.pop(), values.pop()
                values.append(apply_op(a, b, ops.pop()))
            ops.pop()  # discard the matching '('
        else:
            while ops and precedence(ops[-1]) >= precedence(token):
                b, a = values.pop(), values.pop()
                values.append(apply_op(a, b, ops.pop()))
            ops.append(token)

    while ops:
        b, a = values.pop(), values.pop()
        values.append(apply_op(a, b, ops.pop()))

    return values[0]

print(f"计算结果: {advanced_calculator('3+4*2-6/3')}")

案例:调度问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def task_scheduling(tasks):
    """Greedy interval scheduling: choose a maximum number of non-overlapping tasks.

    Args:
        tasks: sequence of (start, end, ...) tuples; a task may start exactly
            when the previously chosen one ends.

    Returns:
        The chosen tasks ordered by end time.  The input is not modified.
    """
    selected = []
    end_time = float('-inf')  # no restriction before the first pick (negative starts allowed)

    for task in sorted(tasks, key=lambda t: t[1]):
        if task[0] >= end_time:
            selected.append(task)
            end_time = task[1]

    return selected

def task_scheduling_with_profit(tasks):
    """Weighted interval scheduling: maximum total profit of pairwise compatible tasks.

    Tasks are (start, end, profit); two tasks are compatible when one ends no
    later than the other starts.  With tasks ordered by end time,
    dp[i] = best profit using only the first i tasks = max(skip task i,
    take it plus the best prefix of tasks ending <= its start).  The prefix
    length is found by binary search, so the total cost is O(n log n).
    The input list is not modified.
    """
    from bisect import bisect_right

    ordered = sorted(tasks, key=lambda t: t[1])
    ends = [t[1] for t in ordered]

    n = len(ordered)
    dp = [0] * (n + 1)

    for i in range(1, n + 1):
        start, _end, profit = ordered[i - 1]
        # number of earlier tasks (among the first i - 1) that end <= start
        j = bisect_right(ends, start, 0, i - 1)
        dp[i] = max(dp[i - 1], dp[j] + profit)

    return dp[n]


tasks = [(1, 3, 50), (2, 5, 20), (4, 6, 60), (6, 7, 30)]
print(f"最大收益: {task_scheduling_with_profit(tasks)}")

五十六、计算几何进阶

点与向量运算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import math

class Point:
    """A 2-D point that doubles as a vector (x, y)."""

    def __init__(self, x=0, y=0):
        self.x = x
        self.y = y

    # --- vector arithmetic (each operation returns a new Point) ---
    def __add__(self, other):
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        return Point(self.x - other.x, self.y - other.y)

    def __mul__(self, k):
        """Scale by a scalar written on the right: ``p * k``."""
        return Point(self.x * k, self.y * k)

    def __truediv__(self, k):
        return Point(self.x / k, self.y / k)

    def __neg__(self):
        return Point(-self.x, -self.y)

    # --- products and norms ---
    def dot(self, other):
        return self.x * other.x + self.y * other.y

    def cross(self, other):
        """z-component of the 3-D cross product (positive when `other` is counter-clockwise of self)."""
        return self.x * other.y - self.y * other.x

    def length_sq(self):
        return self.x ** 2 + self.y ** 2

    def length(self):
        return math.sqrt(self.length_sq())

    def normalize(self):
        """Unit vector in the same direction; the zero vector maps to Point(0, 0)."""
        norm = self.length()
        return self / norm if norm > 0 else Point(0, 0)

    def rotate(self, angle):
        """Rotate counter-clockwise about the origin by `angle` radians."""
        c, s = math.cos(angle), math.sin(angle)
        return Point(self.x * c - self.y * s,
                     self.x * s + self.y * c)

    def angle(self):
        """Polar angle in radians, as returned by math.atan2 (range [-pi, pi])."""
        return math.atan2(self.y, self.x)

    def distance_to(self, other):
        return (self - other).length()

    def __repr__(self):
        return f"Point({self.x}, {self.y})"

def angle_between(a, b):
    """Unsigned angle in radians (0..pi) between non-zero vectors a and b."""
    cos_theta = a.dot(b) / (a.length() * b.length())
    # clamp so rounding error cannot push the argument outside acos's domain
    return math.acos(max(-1, min(1, cos_theta)))

def triangle_area(a, b, c):
    """Area of triangle abc: half the absolute cross product of (b - a) and (c - a)."""
    ab, ac = b - a, c - a
    return abs(ab.cross(ac)) / 2

def orientation(a, b, c):
    """Turn direction of a -> b -> c: 1 counter-clockwise, -1 clockwise, 0 collinear."""
    turn = (b - a).cross(c - a)
    return (turn > 0) - (turn < 0)

# Demo: (p2 - p1) x (p3 - p1) = 3*1 - 4*1 = -1, so the expected output is
# distance 5.0, area 0.5 and orientation -1 (p1 -> p2 -> p3 turns clockwise).
p1 = Point(0, 0)
p2 = Point(3, 4)
p3 = Point(1, 1)

print(f"点距离: {p1.distance_to(p2)}")
print(f"三角形面积: {triangle_area(p1, p2, p3)}")
print(f"方向: {orientation(p1, p2, p3)}")

线段与多边形

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
class Line:
    """Segment from p1 to p2; parameter t in [0, 1] spans the segment."""

    def __init__(self, p1, p2):
        self.p1 = p1
        self.p2 = p2

    def direction(self):
        """Vector p2 - p1."""
        return self.p2 - self.p1

    def length(self):
        return self.direction().length()

    def point_at(self, t):
        """Point p1 + t * (p2 - p1); t outside [0, 1] lies on the extension."""
        return self.p1 + self.direction() * t

    def distance_to_point(self, p):
        """Distance from p to the nearest point of the segment."""
        return (p - self.closest_point(p)).length()

    def closest_point(self, p):
        """Projection of p onto the segment, clamped to the endpoints."""
        d = self.direction()
        proj = (p - self.p1).dot(d)
        d_len_sq = d.dot(d)

        if proj <= 0:
            return self.p1
        if d_len_sq <= proj:
            return self.p2
        return self.p1 + d * (proj / d_len_sq)

def segments_intersect(l1, l2):
    """True if segments l1 and l2 share at least one point (touching counts)."""
    def within_box(p, q, r):
        # q lies in the axis-aligned bounding box of p and r
        return (min(p.x, r.x) <= q.x <= max(p.x, r.x) and
                min(p.y, r.y) <= q.y <= max(p.y, r.y))

    a, b, c, d = l1.p1, l1.p2, l2.p1, l2.p2
    o1 = orientation(a, b, c)
    o2 = orientation(a, b, d)
    o3 = orientation(c, d, a)
    o4 = orientation(c, d, b)

    # proper crossing: each segment's endpoints lie on different sides of the other
    if o1 != o2 and o3 != o4:
        return True

    # collinear cases: an endpoint lies on the other segment
    return ((o1 == 0 and within_box(a, c, b)) or
            (o2 == 0 and within_box(a, d, b)) or
            (o3 == 0 and within_box(c, a, d)) or
            (o4 == 0 and within_box(c, b, d)))

def line_intersection(l1, l2):
    """Intersection of the infinite lines through l1 and l2, or None if (nearly) parallel."""
    d1, d2 = l1.direction(), l2.direction()
    denom = d1.cross(d2)
    if abs(denom) < 1e-10:
        return None
    t = (l2.p1 - l1.p1).cross(d2) / denom
    return l1.p1 + d1 * t

def polygon_area(vertices):
    """Area of a simple polygon by the shoelace formula (vertices in order, either orientation)."""
    twice_area = 0
    for cur, nxt in zip(vertices, vertices[1:] + vertices[:1]):
        twice_area += cur.cross(nxt)
    return abs(twice_area) / 2

def point_in_polygon(p, vertices):
    """Ray-casting (even-odd) test of whether p lies inside the polygon.

    A horizontal ray from p toward +x is intersected with every edge; an odd
    number of crossings means inside.  Points exactly on an edge may be
    classified either way.
    """
    inside = False
    for i in range(len(vertices)):
        a, b = vertices[i], vertices[i - 1]  # edge to the previous vertex (wraps around)
        if (a.y > p.y) != (b.y > p.y):
            x_cross = (b.x - a.x) * (p.y - a.y) / (b.y - a.y) + a.x
            if p.x < x_cross:
                inside = not inside
    return inside

def convex_hull(points):
    """Andrew's monotone chain: hull vertices counter-clockwise, collinear points dropped.

    Inputs with fewer than 3 points are returned unchanged.
    """
    if len(points) < 3:
        return points

    ordered = sorted(points, key=lambda q: (q.x, q.y))

    def half_hull(seq):
        chain = []
        for q in seq:
            # pop while the last two points and q do not make a left turn
            while len(chain) >= 2 and orientation(chain[-2], chain[-1], q) <= 0:
                chain.pop()
            chain.append(q)
        return chain

    lower = half_hull(ordered)
    upper = half_hull(reversed(ordered))
    # each half ends at the other's starting point; drop the duplicates
    return lower[:-1] + upper[:-1]

# Demo: a 4 x 3 axis-aligned rectangle -> area 12.0, and (2, 2) lies inside.
vertices = [Point(0, 0), Point(4, 0), Point(4, 3), Point(0, 3)]
print(f"多边形面积: {polygon_area(vertices)}")
print(f"点 (2, 2) 在多边形内: {point_in_polygon(Point(2, 2), vertices)}")

圆与球

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class Circle:
    """Circle given by a center Point and a radius."""

    def __init__(self, center, radius):
        self.center = center
        self.radius = radius

    def area(self):
        return math.pi * self.radius ** 2

    def circumference(self):
        return 2 * math.pi * self.radius

    def contains_point(self, p):
        """True if p is inside or on the circle (exact comparison, no tolerance)."""
        return (p - self.center).length() <= self.radius

    def distance_to_point(self, p):
        """Distance from p to the disk; 0 when p is inside or on it."""
        d = (p - self.center).length()
        return 0 if d <= self.radius else d - self.radius

    def intersects_circle(self, other):
        """True if the two boundaries cross in two points (tangency excluded)."""
        d = (self.center - other.center).length()
        return abs(self.radius - other.radius) < d < self.radius + other.radius

def circles_intersection(c1, c2):
    """Intersection points of the boundaries of circles c1 and c2.

    Returns [] when the circles are apart, one lies inside the other, or
    they are concentric; otherwise two points (identical when tangent).
    """
    d = (c1.center - c2.center).length()

    if d > c1.radius + c2.radius or d < abs(c1.radius - c2.radius) or d == 0:
        return []

    # a: distance from c1.center to the chord's midpoint along the center line
    a = (c1.radius ** 2 - c2.radius ** 2 + d ** 2) / (2 * d)
    # For tangent circles rounding can make r1^2 - a^2 slightly negative;
    # clamp so math.sqrt does not raise ValueError.
    h = math.sqrt(max(0.0, c1.radius ** 2 - a ** 2))

    mid = c1.center + (c2.center - c1.center) * (a / d)
    offset = (c2.center - c1.center).normalize().rotate(math.pi / 2) * h

    return [mid + offset, mid - offset]

def circle_line_intersection(circle, line):
    """Intersections of the infinite line through `line` with the circle's boundary.

    Solves |p1 + t*d - center|^2 = r^2 for t.  Returns [] when the line misses
    the circle, otherwise two points (equal when tangent).  Assumes
    line.p1 != line.p2 (a zero direction would divide by zero).
    """
    d = line.direction()
    f = line.p1 - circle.center

    qa = d.dot(d)
    qb = 2 * f.dot(d)
    qc = f.dot(f) - circle.radius ** 2

    disc = qb ** 2 - 4 * qa * qc
    if disc < 0:
        return []

    root = math.sqrt(disc)
    t_near = (-qb - root) / (2 * qa)
    t_far = (-qb + root) / (2 * qa)
    return [line.point_at(t_near), line.point_at(t_far)]

def minimum_enclosing_circle(points, eps=1e-9):
    """Smallest circle enclosing all points (randomized incremental, expected O(n)).

    Args:
        points: non-empty sequence of Point.  It is not modified; a shuffled
            copy is processed.
        eps: containment tolerance, so points that lie on the boundary up to
            rounding error are not treated as outside.

    Returns:
        Circle covering every point.
    """
    from random import shuffle

    pts = list(points)
    shuffle(pts)

    def inside(circle, p):
        return (p - circle.center).length() <= circle.radius + eps

    circle = Circle(pts[0], 0)

    for i in range(1, len(pts)):
        if inside(circle, pts[i]):
            continue
        # pts[i] must lie on the boundary of the circle for pts[:i + 1]
        circle = Circle(pts[i], 0)
        for j in range(i):
            if inside(circle, pts[j]):
                continue
            # pts[i] and pts[j] both on the boundary
            circle = Circle((pts[i] + pts[j]) / 2,
                            (pts[i] - pts[j]).length() / 2)
            for k in range(j):
                if inside(circle, pts[k]):
                    continue
                # circumcircle of pts[i], pts[j], pts[k] (d == 0 only if collinear)
                a, b, c = pts[i], pts[j], pts[k]
                d = 2 * (a.x * (b.y - c.y) + b.x * (c.y - a.y) + c.x * (a.y - b.y))
                ux = ((a.x ** 2 + a.y ** 2) * (b.y - c.y) +
                      (b.x ** 2 + b.y ** 2) * (c.y - a.y) +
                      (c.x ** 2 + c.y ** 2) * (a.y - b.y)) / d
                uy = ((a.x ** 2 + a.y ** 2) * (c.x - b.x) +
                      (b.x ** 2 + b.y ** 2) * (a.x - c.x) +
                      (c.x ** 2 + c.y ** 2) * (b.x - a.x)) / d
                center = Point(ux, uy)
                circle = Circle(center, (center - a).length())

    return circle

五十七、字符串算法进阶

扩展 KMP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def z_function(s):
    """Z-array of s: z[i] is the length of the longest common prefix of s and s[i:].

    z[0] is defined as len(s).  Runs in O(len(s)).  Returns [] for an empty string.
    """
    n = len(s)
    if n == 0:
        return []
    z = [0] * n
    z[0] = n

    l, r = 0, 0  # [l, r) is the right-most window known to match a prefix of s
    for i in range(1, n):
        if i < r:
            z[i] = min(r - i, z[i - l])

        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1

        if i + z[i] > r:
            l, r = i, i + z[i]

    return z


def extended_kmp(text, pattern):
    """Start indices of every occurrence of pattern in text, via the Z-array of pattern + '#' + text.

    A position counts only when its Z-value is at least len(pattern), i.e. the
    next len(pattern) characters equal pattern, so '#' appearing in the
    inputs does not cause false matches.
    """
    concat = pattern + '#' + text
    z = z_function(concat)

    n = len(pattern)
    result = []

    for i in range(n + 1, len(concat)):
        if z[i] >= n:
            result.append(i - n - 1)

    return result


s = "aabcaabxaaz"
print(f"Z 函数: {z_function(s)}")

text = "abcabcabc"
pattern = "abc"
print(f"匹配位置: {extended_kmp(text, pattern)}")

Manacher 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def manacher(s):
    """Palindrome radii over t = '#' + '#'.join(s) + '#'.

    p[i] is the largest r with t[i - r : i + r + 1] a palindrome; this r equals
    the length of the corresponding palindrome in s.
    """
    t = '#' + '#'.join(s) + '#'
    size = len(t)
    radius = [0] * size

    mid, reach = 0, 0  # center and right edge of the right-most palindrome seen
    for i in range(size):
        if i < reach:
            radius[i] = min(reach - i, radius[2 * mid - i])

        while True:
            lo, hi = i - radius[i] - 1, i + radius[i] + 1
            if lo < 0 or hi >= size or t[lo] != t[hi]:
                break
            radius[i] += 1

        if i + radius[i] > reach:
            mid, reach = i, i + radius[i]

    return radius


def longest_palindrome(s):
    """Longest palindromic substring of s (leftmost one on ties)."""
    p = manacher(s)
    best = max(p)
    center = p.index(best)
    start = (center - best) // 2
    return s[start:start + best]


def count_all_palindromes(s):
    """Number of palindromic substrings of s, counted by position."""
    return sum((r + 1) // 2 for r in manacher(s))


s = "babad"
print(f"最长回文子串: {longest_palindrome(s)}")
print(f"回文子串总数: {count_all_palindromes(s)}")

字符串哈希

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class StringHash:
    """Polynomial rolling hash of a string with O(1) substring queries.

    hash[i] is the hash of the prefix s[:i]; power[i] = base**i (mod mod).
    """

    def __init__(self, s, base=131, mod=10**9 + 7):
        self.n = len(s)
        self.base = base
        self.mod = mod

        self.power = [1] * (self.n + 1)
        self.hash = [0] * (self.n + 1)

        for i in range(self.n):
            self.power[i + 1] = (self.power[i] * base) % mod
            self.hash[i + 1] = (self.hash[i] * base + ord(s[i])) % mod

    def get_hash(self, l, r):
        """Hash of the half-open slice s[l:r]."""
        # Python's % already returns a value in [0, mod), so no "+ mod" fix-up is needed.
        return (self.hash[r] - self.hash[l] * self.power[r - l]) % self.mod

    def get_full_hash(self):
        """Hash of the whole string."""
        return self.hash[self.n]


def find_common_substring(s1, s2, length):
    """Find a substring of the given length that occurs in both s1 and s2.

    Returns (True, j) with j the first start index in s2 of such a
    substring, or (False, -1).  Hash hits are confirmed by comparing the
    actual characters, so a hash collision cannot produce a false positive.
    """
    h1 = StringHash(s1)
    h2 = StringHash(s2)

    starts_by_hash = {}
    for i in range(len(s1) - length + 1):
        starts_by_hash.setdefault(h1.get_hash(i, i + length), []).append(i)

    for j in range(len(s2) - length + 1):
        candidates = starts_by_hash.get(h2.get_hash(j, j + length))
        if not candidates:
            continue
        piece = s2[j:j + length]
        if any(s1[i:i + length] == piece for i in candidates):
            return True, j

    return False, -1


def longest_common_substring(s1, s2):
    """Longest common substring, by binary search on its length.

    Works because a common substring of length L implies one of every
    shorter length.  Returns the first such substring found in s2.
    """
    low, high = 0, min(len(s1), len(s2))
    result = ""

    while low <= high:
        mid = (low + high) // 2
        found, pos = find_common_substring(s1, s2, mid)

        if found:
            result = s2[pos:pos + mid]
            low = mid + 1
        else:
            high = mid - 1

    return result


s1 = "abcdefg"
s2 = "cdefghi"
print(f"最长公共子串: {longest_common_substring(s1, s2)}")

最小表示法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def minimal_representation(s):
    """Start index of the lexicographically smallest rotation of s (two-pointer method, O(n))."""
    n = len(s)
    doubled = s + s
    i, j, k = 0, 1, 0

    while i < n and j < n and k < n:
        a, b = doubled[i + k], doubled[j + k]
        if a == b:
            k += 1
            continue
        # the losing candidate and the k positions after it cannot be minimal
        if a > b:
            i = max(i + k + 1, j + 1)
        else:
            j = max(j + k + 1, i + 1)
        k = 0

    return min(i, j)


def get_minimal_string(s):
    """The lexicographically smallest rotation of s."""
    pos = minimal_representation(s)
    return s[pos:] + s[:pos]


s = "bbaa"
print(f"最小表示: {get_minimal_string(s)}")

五十八、图论算法进阶

强连通分量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def tarjan_scc(n, adj):
    """Strongly connected components of a directed graph (Tarjan, recursive DFS).

    Args:
        n: number of vertices 0..n-1.
        adj: adjacency lists, adj[v] = successors of v.

    Returns:
        List of components (lists of vertices), in the order they are completed.
    """
    order_of = [0] * n   # 1-based discovery index; 0 means unvisited
    low = [0] * n
    on_stack = [False] * n
    stack = []
    components = []
    counter = 0

    def visit(v):
        nonlocal counter
        counter += 1
        order_of[v] = low[v] = counter
        stack.append(v)
        on_stack[v] = True

        for w in adj[v]:
            if order_of[w] == 0:
                visit(w)
                low[v] = min(low[v], low[w])
            elif on_stack[w]:
                low[v] = min(low[v], order_of[w])

        if low[v] != order_of[v]:
            return  # v is not the root of its component
        component = []
        w = None
        while w != v:
            w = stack.pop()
            on_stack[w] = False
            component.append(w)
        components.append(component)

    for v in range(n):
        if order_of[v] == 0:
            visit(v)

    return components

def kosaraju_scc(n, adj):
    """Strongly connected components via Kosaraju's two-pass DFS.

    Pass 1 records finish order on the graph; pass 2 explores the reversed
    graph in decreasing finish time, one component per new root.
    """
    seen = [False] * n
    finish_order = []

    def forward(v):
        seen[v] = True
        for w in adj[v]:
            if not seen[w]:
                forward(w)
        finish_order.append(v)

    for v in range(n):
        if not seen[v]:
            forward(v)

    reverse_adj = [[] for _ in range(n)]
    for v in range(n):
        for w in adj[v]:
            reverse_adj[w].append(v)

    assigned = [False] * n
    components = []

    def backward(v, component):
        assigned[v] = True
        component.append(v)
        for w in reverse_adj[v]:
            if not assigned[w]:
                backward(w, component)

    for v in reversed(finish_order):
        if not assigned[v]:
            component = []
            backward(v, component)
            components.append(component)

    return components

# Directed demo graph: 0 -> 1 -> 2 -> 0 is a cycle, then 2 -> 3 -> 4.
# {0, 1, 2} forms one strongly connected component; 3 and 4 are singletons.
n = 5
adj = [[1], [2], [0, 3], [4], []]
print(f"强连通分量 (Tarjan): {tarjan_scc(n, adj)}")
print(f"强连通分量 (Kosaraju): {kosaraju_scc(n, adj)}")

割点与桥

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def find_bridges(n, adj):
    """Bridges of an undirected graph via DFS low-link values.

    Returns (parent, child) DFS-tree edges whose removal disconnects the
    graph.  NOTE: every edge back to the DFS parent is skipped, so parallel
    edges between the same pair of vertices are not supported.
    """
    disc = [0] * n   # 1-based discovery time
    low = [0] * n
    visited = [False] * n
    bridges = []
    clock = 0

    def dfs(v, parent):
        nonlocal clock
        visited[v] = True
        clock += 1
        disc[v] = low[v] = clock

        for u in adj[v]:
            if u == parent:
                continue
            if not visited[u]:
                dfs(u, v)
                low[v] = min(low[v], low[u])
                # u's subtree has no back edge reaching v or above
                if low[u] > disc[v]:
                    bridges.append((v, u))
            else:
                low[v] = min(low[v], disc[u])

    for v in range(n):
        if not visited[v]:
            dfs(v, -1)

    return bridges

def find_articulation_points(n, adj):
    """Articulation (cut) vertices of an undirected graph, sorted ascending.

    A non-root vertex v is a cut vertex when some DFS child u has
    low[u] >= disc[v]; a DFS root is one when it has more than one tree child.
    """
    disc = [0] * n
    low = [0] * n
    visited = [False] * n
    is_cut = [False] * n
    clock = 0

    def dfs(v, parent):
        nonlocal clock
        visited[v] = True
        clock += 1
        disc[v] = low[v] = clock
        tree_children = 0

        for u in adj[v]:
            if u == parent:
                continue
            if visited[u]:
                low[v] = min(low[v], disc[u])
                continue
            tree_children += 1
            dfs(u, v)
            low[v] = min(low[v], low[u])
            if parent != -1 and low[u] >= disc[v]:
                is_cut[v] = True

        if parent == -1 and tree_children > 1:
            is_cut[v] = True

    for v in range(n):
        if not visited[v]:
            dfs(v, -1)

    return [v for v, cut in enumerate(is_cut) if cut]

# Undirected demo graph: triangle 0-1-2, with vertices 3 and 4 hanging off 2.
# Edges 2-3 and 2-4 are bridges; vertex 2 is the only articulation point.
n = 5
adj = [[1, 2], [0, 2], [0, 1, 3, 4], [2], [2]]
print(f"桥: {find_bridges(n, adj)}")
print(f"割点: {find_articulation_points(n, adj)}")

2-SAT 问题

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class TwoSAT:
    """2-SAT solver using the implication graph and Kosaraju SCCs.

    Node numbering: node 2*i stands for the literal "x_i is True" and node
    2*i + 1 for "x_i is False" (equivalently, node 2*a + v means
    "x_a != v").
    """

    def __init__(self, n):
        # n boolean variables -> 2n literal nodes
        self.n = n
        self.adj = [[] for _ in range(2 * n)]   # implication graph
        self.radj = [[] for _ in range(2 * n)]  # its reverse, for Kosaraju's second pass

    def add_clause(self, a, va, b, vb):
        """Add the clause (x_a == va) OR (x_b == vb); va and vb are 0 or 1.

        Encoded as the implications (x_a != va) -> (x_b == vb) and
        (x_b != vb) -> (x_a == va).
        """
        self.adj[2 * a + va].append(2 * b + 1 - vb)
        self.adj[2 * b + vb].append(2 * a + 1 - va)
        self.radj[2 * b + 1 - vb].append(2 * a + va)
        self.radj[2 * a + 1 - va].append(2 * b + vb)

    def solve(self):
        """Return (True, assignment) if satisfiable, else (False, [])."""
        visited = [False] * (2 * self.n)
        order = []

        # Pass 1: record DFS finish order on the implication graph.
        def dfs1(v):
            visited[v] = True
            for u in self.adj[v]:
                if not visited[u]:
                    dfs1(u)
            order.append(v)

        for v in range(2 * self.n):
            if not visited[v]:
                dfs1(v)

        comp = [-1] * (2 * self.n)
        cid = [0]

        # Pass 2: on the reverse graph in decreasing finish time; component
        # ids therefore increase along the topological order of the
        # implication graph.
        def dfs2(v):
            comp[v] = cid[0]
            for u in self.radj[v]:
                if comp[u] == -1:
                    dfs2(u)

        for v in reversed(order):
            if comp[v] == -1:
                dfs2(v)
                cid[0] += 1

        # A variable and its negation in the same SCC -> unsatisfiable.
        for i in range(self.n):
            if comp[2 * i] == comp[2 * i + 1]:
                return False, []

        # Choose the literal whose component comes later in topological order
        # (larger id): x_i is True iff comp["x_i is True"] > comp["x_i is False"].
        assignment = [False] * self.n
        for i in range(self.n):
            assignment[i] = comp[2 * i] > comp[2 * i + 1]

        return True, assignment

# Demo with x0, x1, x2, where add_clause(a, va, b, vb) means
# (x_a == va) or (x_b == vb):
#   (x0 or not x1), (x1 or x2), (not x0 or not x2)
# This is satisfiable, e.g. x0 = True, x1 = True, x2 = False.
ts = TwoSAT(3)
ts.add_clause(0, 1, 1, 0)
ts.add_clause(1, 1, 2, 1)
ts.add_clause(0, 0, 2, 0)

satisfiable, assignment = ts.solve()
print(f"可满足: {satisfiable}")
if satisfiable:
    print(f"赋值: {assignment}")

差分约束

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def difference_constraints(n, constraints):
    """Solve a system of difference constraints x_u - x_v <= c with Bellman-Ford.

    Args:
        n: number of variables x_0 .. x_{n-1}.
        constraints: iterable of (u, v, c) meaning x_u - x_v <= c.

    Returns:
        A feasible list [x_0, ..., x_{n-1}] (every value <= 0), or None when the
        constraints contain a negative cycle (infeasible).
    """
    source = n  # virtual node with a 0-weight edge to every variable
    edges_from = [[] for _ in range(n + 1)]
    for u, v, c in constraints:
        edges_from[v].append((u, c))  # x_u <= x_v + c
    edges_from[source].extend((i, 0) for i in range(n))

    dist = [float('inf')] * (n + 1)
    dist[source] = 0

    # n + 1 nodes -> n rounds of relaxation suffice without negative cycles
    for _ in range(n):
        for u, out in enumerate(edges_from):
            for v, w in out:
                if dist[u] + w < dist[v]:
                    dist[v] = dist[u] + w

    has_negative_cycle = any(
        dist[u] + w < dist[v]
        for u, out in enumerate(edges_from)
        for v, w in out
    )
    return None if has_negative_cycle else dist[:-1]


constraints = [(0, 1, 3), (1, 2, -2), (2, 0, 1)]
result = difference_constraints(3, constraints)
print(f"差分约束解: {result}")

五十九、数论进阶

欧拉函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def euler_phi(n):
    """Euler's totient phi(n) by trial-division factorization in O(sqrt(n))."""
    result = n
    rest = n
    p = 2
    while p * p <= rest:
        if rest % p == 0:
            while rest % p == 0:
                rest //= p
            result -= result // p  # multiply by (1 - 1/p)
        p += 1
    if rest > 1:
        # one prime factor larger than sqrt(n) remains
        result -= result // rest
    return result


def euler_phi_sieve(n):
    """List phi[0..n] (phi[0] = 0) computed with a sieve."""
    phi = list(range(n + 1))
    for p in range(2, n + 1):
        if phi[p] != p:
            continue  # p is composite: already reduced by a smaller prime
        for multiple in range(p, n + 1, p):
            phi[multiple] -= phi[multiple] // p
    return phi


def sum_phi(n):
    """Sum of phi(k) for k = 0..n."""
    return sum(euler_phi_sieve(n))


print(f"φ(12) = {euler_phi(12)}")
print(f"φ 数组 (1-10): {euler_phi_sieve(10)[1:]}")

中国剩余定理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def extended_gcd(a, b):
    """Return (g, x, y) with g == gcd(a, b) and a*x + b*y == g."""
    if b == 0:
        return a, 1, 0
    g, x, y = extended_gcd(b, a % b)
    return g, y, x - (a // b) * y


def crt(remainders, moduli):
    """Solve x ≡ remainders[i] (mod moduli[i]) for all i (generalized CRT).

    The moduli need not be pairwise coprime.  Congruences are merged one at
    a time: given x ≡ r (mod m) and a new x ≡ a (mod m2), write x = r + m*t;
    then m*t ≡ a - r (mod m2) is solvable iff g = gcd(m, m2) divides a - r,
    and t ≡ ((a - r)/g) * p (mod m2/g) where m*p + m2*q = g.

    Returns:
        The smallest non-negative solution (unique modulo the lcm of the
        moduli), or None if the system is inconsistent.  Moduli are assumed
        positive.
    """
    n = len(remainders)
    result = remainders[0]
    m = moduli[0]

    for i in range(1, n):
        a = remainders[i]
        m2 = moduli[i]

        g, p, _ = extended_gcd(m, m2)

        if (a - result) % g != 0:
            return None

        lcm = m // g * m2
        result = (result + m * ((a - result) // g * p % (m2 // g))) % lcm
        m = lcm

    # Normalize into [0, m) even when only a single congruence was given.
    return result % m


remainders = [2, 3, 2]
moduli = [3, 5, 7]
print(f"CRT 结果: {crt(remainders, moduli)}")

离散对数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def discrete_log(a, b, m):
    """Smallest x >= 0 with a**x ≡ b (mod m) by baby-step giant-step, or None.

    Baby steps store b * a**j -> j for 0 <= j < n (n ≈ sqrt(m)); giant steps
    test a**(i*n) for i = 1..n, and a hit gives x = i*n - j.  Complete when
    gcd(a, m) == 1; every returned value is checked with pow(), so a wrong
    exponent is never returned.  O(sqrt(m)) time and memory.
    """
    b %= m
    if 1 % m == b:
        return 0  # a**0 == 1

    n = int(m ** 0.5) + 1

    # Later (larger) j overwrite earlier ones, so the first hit below yields
    # the smallest exponent within that giant step.
    baby = {}
    cur = b
    for j in range(n):
        baby[cur] = j
        cur = cur * a % m

    giant = pow(a, n, m)
    cur = 1
    for i in range(1, n + 1):
        cur = cur * giant % m
        j = baby.get(cur)
        if j is not None:
            x = i * n - j
            if pow(a, x, m) == b:
                return x

    return None

print(f"离散对数 log_2(3) mod 7 = {discrete_log(2, 3, 7)}")

原根

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def primitive_root(p):
    """Smallest primitive root modulo a prime p (p must be prime).

    g generates the multiplicative group mod p iff g**((p-1)/q) != 1 (mod p)
    for every prime q dividing p - 1.  For p == 2 the group is {1}, so the
    answer is 1.  Returns None if no root is found (e.g. p < 2).
    """
    if p == 2:
        return 1

    phi = p - 1
    factors = set()
    n = phi
    i = 2
    while i * i <= n:
        if n % i == 0:
            factors.add(i)
            while n % i == 0:
                n //= i
        i += 1
    if n > 1:
        factors.add(n)

    # candidates must be units mod p, i.e. 2 <= g <= p - 1
    for g in range(2, p):
        if all(pow(g, phi // f, p) != 1 for f in factors):
            return g
    return None

print(f"7 的原根: {primitive_root(7)}")

NTT 快速数论变换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def ntt(a, invert, root, mod):
    """In-place iterative number-theoretic transform over Z_mod.

    len(a) must be a power of two that divides mod - 1, and root must be a
    primitive root of mod.  With invert=True the inverse transform is
    computed, including the final division by len(a).
    """
    n = len(a)
    # bit-reversal permutation
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            a[i], a[j] = a[j], a[i]

    length = 2
    while length <= n:
        # principal length-th root of unity (or its inverse)
        wlen = pow(root, (mod - 1) // length, mod)
        if invert:
            wlen = pow(wlen, mod - 2, mod)

        for i in range(0, n, length):
            w = 1
            for j in range(i, i + length // 2):
                u = a[j]
                v = a[j + length // 2] * w % mod
                a[j] = (u + v) % mod
                a[j + length // 2] = (u - v + mod) % mod
                w = w * wlen % mod

        length *= 2

    if invert:
        n_inv = pow(n, mod - 2, mod)
        for i in range(n):
            a[i] = a[i] * n_inv % mod


def polynomial_multiply_ntt(a, b):
    """Coefficients (lowest degree first) of the product of polynomials a and b.

    Uses NTT modulo 998244353 (primitive root 3), so each output coefficient
    is reduced mod 998244353 and the result is exact only while the true
    coefficients are non-negative and below that modulus.  Returns [] if
    either input is empty.  The inputs are not modified.
    """
    MOD = 998244353
    ROOT = 3

    if not a or not b:
        return []

    result_size = len(a) + len(b) - 1
    n = 1
    while n < result_size:
        n *= 2

    fa = a + [0] * (n - len(a))
    fb = b + [0] * (n - len(b))

    ntt(fa, False, ROOT, MOD)
    ntt(fb, False, ROOT, MOD)

    for i in range(n):
        fa[i] = fa[i] * fb[i] % MOD

    ntt(fa, True, ROOT, MOD)

    return fa[:result_size]


a = [1, 2, 3]
b = [4, 5, 6]
print(f"多项式乘积: {polynomial_multiply_ntt(a, b)}")

六十、更多真题解析

2021 年真题:卡片

题目:小蓝有很多数字卡片,每张卡片上都是数字 0 到 9。小蓝准备用这些卡片来拼一些数,他想从 1 开始拼出正整数,每拼一个就保存起来,卡片就不能用来拼其它数了。小蓝想知道自己能从 1 拼到多少。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def card_puzzle():
    """Largest N such that 1..N can all be spelled using 2021 cards of each digit 0-9."""
    remaining = [2021] * 10
    last = 0  # largest number spelled so far

    while True:
        need = [0] * 10
        for ch in str(last + 1):
            need[int(ch)] += 1

        if any(need[d] > remaining[d] for d in range(10)):
            return last

        for d in range(10):
            remaining[d] -= need[d]
        last += 1


print(f"能拼到的最大数: {card_puzzle()}")

2021 年真题:直线

题目:在平面直角坐标系中,两点可以确定一条直线。如果给定两点坐标,求经过这两点的直线方程。现在给定多个点,求这些点能确定多少条不同的直线。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from math import gcd


def count_lines(points):
    """Number of distinct straight lines through pairs of distinct points.

    Each line is stored in a canonical form so the same line found from
    different pairs counts once: vertical lines as ('v', x), horizontal lines
    as ('h', y), others as the reduced integer triple (a, b, c) of
    a*x + b*y + c = 0 with a > 0.  Integer coordinates are assumed.
    Pairs of identical points determine no line and are skipped.
    """
    lines = set()
    n = len(points)

    for i in range(n):
        x1, y1 = points[i]
        for j in range(i + 1, n):
            x2, y2 = points[j]

            if (x1, y1) == (x2, y2):
                continue
            if x1 == x2:
                lines.add(('v', x1))
            elif y1 == y2:
                lines.add(('h', y1))
            else:
                a = y2 - y1
                b = x1 - x2
                c = x2 * y1 - x1 * y2

                g = gcd(gcd(abs(a), abs(b)), abs(c))
                a, b, c = a // g, b // g, c // g

                if a < 0:  # a != 0 here, so fixing its sign is enough
                    a, b, c = -a, -b, -c

                lines.add((a, b, c))

    return len(lines)


points = [(0, 0), (1, 1), (2, 2), (0, 1), (1, 0)]
print(f"不同直线数: {count_lines(points)}")

2021 年真题:货物摆放

题目:小蓝有一个超大的仓库,可以摆放很多货物。现在,小蓝有 n 箱货物要摆放在仓库,每箱货物都是规则的正方体。小蓝规定了“长”、“宽”、“高”三个互相垂直的方向,每箱货物的边都必须严格平行于长、宽、高三个方向。小蓝希望所有的货物最终摆成一个大的长方体,即在长、宽、高的方向上分别堆 a、b、c 箱货物,满足 a × b × c = n。请问有多少种堆放货物的方案?(a、b、c 的顺序不同视为不同方案,例如 n = 4 时共有 6 种。)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def count_arrangements(n):
    """Number of ordered triples (a, b, c) of positive integers with a*b*c == n.

    All divisors of n are found once by trial division up to sqrt(n).  For
    each a among them, the valid b are exactly the divisors of n that also
    divide n // a (c is then determined).  Total cost is O(sqrt(n) + d(n)^2)
    instead of one trial division per divisor.
    """
    divisors = []
    i = 1
    while i * i <= n:
        if n % i == 0:
            divisors.append(i)
            if i != n // i:
                divisors.append(n // i)
        i += 1

    count = 0
    for a in divisors:
        rest = n // a
        for b in divisors:
            if rest % b == 0:
                count += 1

    return count


print(f"摆放方案数 (n=4): {count_arrangements(4)}")

2020 年真题:既约分数

题目:如果一个分数的分子和分母的最大公约数是 1,这个分数称为既约分数。请问,有多少个既约分数,分子和分母都是 1 到 2020 之间的整数(包括 1 和 2020)?

1
2
3
4
5
6
7
8
9
def count_reduced_fractions(n):
    """Number of pairs (i, j) with 1 <= i, j <= n and gcd(i, j) == 1.

    For j >= 2 there are phi(j) numerators i < j coprime to j, and the pairs
    with i > j mirror those; the only coprime pair on the diagonal is
    (1, 1).  Hence the count is 2 * (phi(1) + ... + phi(n)) - 1, computed
    with a phi sieve in O(n log log n) instead of n^2 gcd calls.  The
    function no longer depends on a gcd imported elsewhere in the file.
    """
    if n < 1:
        return 0

    phi = list(range(n + 1))
    for p in range(2, n + 1):
        if phi[p] == p:  # p is prime
            for k in range(p, n + 1, p):
                phi[k] -= phi[k] // p

    return 2 * sum(phi[1:]) - 1


print(f"既约分数数量: {count_reduced_fractions(2020)}")

2017 年真题:青蛙跳杯子

题目:X 星球的青蛙很特别,它们可以从一个杯子跳到另一个杯子。有 N + 1 个杯子排成一排(编号 0 到 N),青蛙初始在第 0 个杯子,目标是跳到第 N 个杯子。青蛙每次可以向前跳 1 格或 2 格。求有多少种跳法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def frog_jump(n):
    """Ways to advance exactly n cells using steps of 1 or 2 (a Fibonacci sequence).

    For n <= 2 the value n itself is returned (so 0 for n == 0).
    """
    if n <= 2:
        return n

    prev, cur = 1, 2  # ways for distances 1 and 2
    for _ in range(3, n + 1):
        prev, cur = cur, prev + cur
    return cur


print(f"跳法数量 (n=10): {frog_jump(10)}")

六十一、高级数据结构

左偏树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class LeftistNode:
    """Node of a leftist min-heap; ``dist`` is the null-path length (0 for a node without a right child)."""

    def __init__(self, val):
        self.val = val
        self.left = self.right = None
        self.dist = 0


class LeftistHeap:
    """Leftist min-heap: push and pop are implemented by merging right spines."""

    def __init__(self):
        self.root = None

    def _merge(self, a, b):
        """Merge the heaps rooted at a and b; return the new root."""
        if a is None or b is None:
            return b if a is None else a

        if a.val > b.val:
            a, b = b, a  # keep the smaller root on top

        a.right = self._merge(a.right, b)

        # Leftist property: dist(left) >= dist(right), with dist(None) = -1.
        left_dist = a.left.dist if a.left else -1
        right_dist = a.right.dist if a.right else -1
        if left_dist < right_dist:
            a.left, a.right = a.right, a.left
            left_dist, right_dist = right_dist, left_dist
        a.dist = right_dist + 1

        return a

    def push(self, val):
        self.root = self._merge(self.root, LeftistNode(val))

    def pop(self):
        """Remove and return the minimum, or None when the heap is empty."""
        root = self.root
        if root is None:
            return None
        self.root = self._merge(root.left, root.right)
        return root.val

    def top(self):
        return None if self.root is None else self.root.val

    def is_empty(self):
        return self.root is None


heap = LeftistHeap()
for x in [5, 3, 7, 1, 9, 2]:
    heap.push(x)

while not heap.is_empty():
    print(heap.pop(), end=' ')
print()

可并堆

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class BinomialNode:
    """Node of a binomial tree (first-child / next-sibling representation)."""

    def __init__(self, val):
        self.val = val
        self.degree = 0       # number of children
        self.parent = None
        self.child = None     # most recently linked (highest-degree) child
        self.sibling = None   # next root in the root list, or next sibling


class BinomialHeap:
    """Min binomial heap (a mergeable heap).

    ``head`` is the root list, kept in increasing order of degree.
    Supports ``insert``, ``get_min`` and ``extract_min`` (added: the original
    template had no way to remove elements) plus ``is_empty``.
    """

    def __init__(self):
        self.head = None

    def _merge_trees(self, t1, t2):
        """Link two trees of equal degree; the larger root becomes a child."""
        if t1.val > t2.val:
            t1, t2 = t2, t1
        t2.parent = t1
        t2.sibling = t1.child
        t1.child = t2
        t1.degree += 1
        return t1

    def _merge_heaps(self, h1, h2):
        """Interleave two degree-sorted root lists (no linking yet)."""
        if not h1:
            return h2
        if not h2:
            return h1

        if h1.degree < h2.degree:
            h1.sibling = self._merge_heaps(h1.sibling, h2)
            return h1
        h2.sibling = self._merge_heaps(h1, h2.sibling)
        return h2

    def _union(self, h2):
        """Absorb the roots of heap ``h2``, linking equal-degree trees."""
        self.head = self._merge_heaps(self.head, h2.head)

        if not self.head:
            return

        prev = None
        curr = self.head
        next_node = curr.sibling

        while next_node:
            # Advance when degrees differ, or when three equal degrees are in a
            # row (link the last two of them on a later step).
            if curr.degree != next_node.degree or \
                    (next_node.sibling and next_node.sibling.degree == curr.degree):
                prev = curr
                curr = next_node
            else:
                if curr.val <= next_node.val:
                    curr.sibling = next_node.sibling
                    self._merge_trees(curr, next_node)
                else:
                    if not prev:
                        self.head = next_node
                    else:
                        prev.sibling = next_node
                    self._merge_trees(next_node, curr)
                    curr = next_node
            next_node = curr.sibling

    def insert(self, val):
        """Insert ``val`` by union with a one-node heap."""
        new_heap = BinomialHeap()
        new_heap.head = BinomialNode(val)
        self._union(new_heap)

    def get_min(self):
        """Return the minimum value without removing it (None if empty)."""
        if not self.head:
            return None

        min_val = self.head.val
        curr = self.head.sibling
        while curr:
            if curr.val < min_val:
                min_val = curr.val
            curr = curr.sibling
        return min_val

    def extract_min(self):
        """Remove and return the minimum value (None if the heap is empty)."""
        if self.head is None:
            return None

        # Find the minimum root and the root preceding it.
        best_prev, best = None, self.head
        prev, curr = self.head, self.head.sibling
        while curr:
            if curr.val < best.val:
                best_prev, best = prev, curr
            prev, curr = curr, curr.sibling

        # Unlink it from the root list.
        if best_prev is None:
            self.head = best.sibling
        else:
            best_prev.sibling = best.sibling

        # Children are stored highest degree first (see _merge_trees); reverse
        # them into a root list of increasing degree and union it back in.
        reversed_children = None
        child = best.child
        while child:
            nxt = child.sibling
            child.parent = None
            child.sibling = reversed_children
            reversed_children = child
            child = nxt

        orphans = BinomialHeap()
        orphans.head = reversed_children
        self._union(orphans)
        return best.val

    def is_empty(self):
        """True when the heap holds no values."""
        return self.head is None

树套树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class FenwickTree:
    """Binary indexed tree over positions 1..n: point add and prefix sum."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # index 0 is unused

    def update(self, i, delta):
        """Add ``delta`` at position ``i`` (1-based).

        Raises:
            IndexError: if ``i < 1``.  For such ``i`` the step ``i & -i`` is
                (or becomes) 0, so the loop would otherwise never terminate.
        """
        if i < 1:
            raise IndexError(f"Fenwick tree positions start at 1, got {i}")
        while i <= self.n:
            self.tree[i] += delta
            i += i & -i

    def query(self, i):
        """Return the sum of positions 1..i (0 when ``i <= 0``)."""
        result = 0
        while i > 0:
            result += self.tree[i]
            i -= i & -i
        return result


class TwoDTree:
    """2-D Fenwick tree (a Fenwick tree of Fenwick trees) over [1..n] x [1..m]."""

    def __init__(self, n, m):
        self.n = n
        self.m = m
        self.trees = [FenwickTree(m) for _ in range(n + 1)]

    def update(self, x, y, delta):
        """Add ``delta`` at cell (x, y).

        Raises:
            IndexError: if ``x < 1`` or ``y < 1`` (would loop forever).
        """
        if x < 1 or y < 1:
            raise IndexError(f"2-D Fenwick cells start at (1, 1), got ({x}, {y})")
        i = x
        while i <= self.n:
            self.trees[i].update(y, delta)
            i += i & -i

    def query(self, x, y):
        """Return the sum over the rectangle [1..x] x [1..y]."""
        result = 0
        i = x
        while i > 0:
            result += self.trees[i].query(y)
            i -= i & -i
        return result

    def range_query(self, x1, y1, x2, y2):
        """Sum over [x1..x2] x [y1..y2] by inclusion-exclusion."""
        return (self.query(x2, y2) - self.query(x1 - 1, y2) -
                self.query(x2, y1 - 1) + self.query(x1 - 1, y1 - 1))

六十二、算法竞赛技巧

快速幂与矩阵快速幂

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def matrix_mult(A, B, mod=10**9 + 7):
    """Return (A x B) % mod for list-of-lists matrices (A: n x k, B: k x m).

    Products are accumulated per row and reduced once per cell; zero entries
    of A are skipped (common in companion matrices).
    """
    n = len(A)
    m = len(B[0])
    k = len(B)
    C = [[0] * m for _ in range(n)]

    for i in range(n):
        row_a = A[i]
        row_c = C[i]
        for p in range(k):
            a = row_a[p]
            if a == 0:
                continue
            row_b = B[p]
            for j in range(m):
                row_c[j] += a * row_b[j]
        for j in range(m):
            row_c[j] %= mod

    return C


def matrix_pow(M, n, mod=10**9 + 7):
    """Return M**n (mod ``mod``) by binary exponentiation; M**0 is the identity."""
    size = len(M)
    result = [[1 if i == j else 0 for j in range(size)] for i in range(size)]
    base = M

    while n > 0:
        if n & 1:
            result = matrix_mult(result, base, mod)
        n >>= 1
        if n:  # skip the final, unused squaring
            base = matrix_mult(base, base, mod)

    return result


def fibonacci_matrix(n):
    """Return F(n) mod 1e9+7 via [[1, 1], [1, 0]] ** (n - 1); n <= 1 returns n."""
    if n <= 1:
        return n

    M = [[1, 1], [1, 0]]
    result = matrix_pow(M, n - 1)
    return result[0][0]


def linear_recurrence(coeffs, init, n, mod=10**9 + 7):
    """Return a_n (mod ``mod``) for a_n = sum(coeffs[j] * a_{n-1-j}), j < k.

    Args:
        coeffs: k coefficients; ``coeffs[0]`` multiplies a_{n-1}.
        init: the first k terms a_0 .. a_{k-1}.
        n: index of the requested term (must be >= 0).

    Raises:
        ValueError: for negative ``n`` (previously ``init[n]`` silently
            indexed from the end).
    """
    k = len(coeffs)

    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if n < k:
        # Reduce like every other result (previously returned unreduced).
        return init[n] % mod

    # Companion matrix: shifts the window [a_i .. a_{i+k-1}] forward by one.
    M = [[0] * k for _ in range(k)]
    for i in range(k - 1):
        M[i][i + 1] = 1
    for i in range(k):
        M[k - 1][i] = coeffs[k - 1 - i]

    result = matrix_pow(M, n, mod)

    ans = 0
    for i in range(k):
        ans = (ans + result[0][i] * init[i]) % mod

    return ans


print(f"斐波那契第 10 项: {fibonacci_matrix(10)}")

状态压缩技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def count_bits(n):
    """Return the number of set bits (popcount) of a non-negative integer.

    Raises:
        ValueError: for negative ``n``.  Python ints behave as infinite two's
            complement, so repeatedly shifting a negative value never reaches
            0 and the original loop hung.
    """
    if n < 0:
        raise ValueError("count_bits requires a non-negative integer")
    return bin(n).count("1")

def lowbit(n):
    """Return the lowest set bit of ``n`` as a value (0 for 0); -n == ~n + 1."""
    return n & (~n + 1)

def enumerate_subsets(mask):
    """List every submask of ``mask`` in decreasing order, ending with 0.

    Uses the classic ``(sub - 1) & mask`` step to jump to the next smaller
    submask.
    """
    found = []
    sub = mask
    while True:
        found.append(sub)
        if sub == 0:
            return found
        sub = (sub - 1) & mask

def next_permutation_of_bits(mask, n):
    """Next larger integer with the same popcount as ``mask`` (Gosper's hack).

    Returns 0 when ``mask`` is 0 or when the next value would not fit in
    ``n`` bits.
    """
    if not mask:
        return 0

    low = mask & -mask      # lowest set bit
    ripple = mask + low     # carry the lowest block of ones one position left
    if ripple >= 1 << n:
        return 0

    # Re-insert the ones that were carried away, packed at the bottom.
    ones = ((ripple ^ mask) >> 2) // low
    return ripple | ones

def traveling_salesman_dp(dist):
    """Held-Karp bitmask DP: shortest closed tour from city 0 through all cities.

    dp[mask][u] is the length of the shortest path that starts at city 0,
    visits exactly the cities in ``mask`` (always including 0) and ends at u.
    Only masks containing city 0 are expanded and unreachable states are
    skipped; both are free savings that do not change the result.

    Args:
        dist: n x n distance matrix.

    Returns:
        The tour length, ``inf`` if no tour exists, and 0 for zero or one
        city (previously n == 1 returned ``inf`` and n == 0 raised IndexError).
    """
    n = len(dist)
    if n <= 1:
        return 0
    INF = float('inf')

    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0

    for mask in range(1, 1 << n, 2):  # odd masks = those containing city 0
        row = dp[mask]
        for u in range(n):
            cost = row[u]
            if cost == INF or not (mask >> u) & 1:
                continue
            for v in range(n):
                if (mask >> v) & 1:
                    continue
                new_mask = mask | (1 << v)
                candidate = cost + dist[u][v]
                if candidate < dp[new_mask][v]:
                    dp[new_mask][v] = candidate

    full_mask = (1 << n) - 1
    return min(dp[full_mask][u] + dist[u][0] for u in range(1, n))


dist = [[0, 10, 15, 20], [10, 0, 35, 25],
        [15, 35, 0, 30], [20, 25, 30, 0]]
print(f"TSP 最短路径: {traveling_salesman_dp(dist)}")

记忆化搜索模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from functools import lru_cache

def memo_template():
    """Skeleton of top-down DP with memoisation (记忆化搜索); not runnable as-is.

    ``is_terminal``, ``base_case``, ``initial_value``, ``get_transitions``,
    ``combine`` and ``initial_state`` are placeholders that this snippet does
    not define; they must exist (e.g. at module level) before calling, or a
    NameError is raised.
    """
    # lru_cache keys on the argument, so each distinct ``state`` is solved
    # once; states therefore have to be hashable.
    @lru_cache(maxsize=None)
    def dfs(state):
        if is_terminal(state):
            return base_case(state)

        result = initial_value()

        # Fold the answers of all successor states into ``result``.
        for next_state in get_transitions(state):
            result = combine(result, dfs(next_state))

        return result

    return dfs(initial_state)

def solve_with_memo():
    """Skeleton of a memoised DP over positions 0..n-1 with two state values.

    Not runnable as-is: ``n``, ``is_valid``, ``get_choices``,
    ``update_state1``, ``update_state2``, ``get_value``, ``initial_state1``
    and ``initial_state2`` are placeholders this snippet does not define;
    they must be supplied before calling, otherwise a NameError is raised.

    Returns the maximum total ``get_value`` over choice sequences whose final
    states pass ``is_valid``; ``float('-inf')`` means no valid sequence.
    """
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def dp(pos, state1, state2):
        # End of the sequence: 0 if the final states are acceptable, otherwise
        # -inf so this branch can never win the max below.
        if pos == n:
            return 0 if is_valid(state1, state2) else float('-inf')

        result = float('-inf')

        for choice in get_choices(pos, state1, state2):
            new_state1 = update_state1(state1, choice)
            new_state2 = update_state2(state2, choice)
            result = max(result, dp(pos + 1, new_state1, new_state2) + get_value(choice))

        return result

    return dp(0, initial_state1, initial_state2)

六十三、模拟与实现技巧

大整数运算完整实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
class BigInt:
    """Arbitrary-precision signed integer stored as decimal digits.

    ``digits`` is little-endian (``digits[0]`` is the units digit) and
    ``sign`` is 1 or -1; zero is always normalised to ``[0]`` with sign 1.
    Arithmetic and comparisons accept another BigInt or a plain ``int``.

    Fixes over the original teaching version:
      * ``__add__``/``__sub__`` started from ``BigInt()``, whose digit list is
        already ``[0]``, and appended the result digits after it, so every
        sum/difference came out multiplied by 10.
      * ``==`` compared identity and ``<=``/``>=`` raised TypeError; proper
        equality, hashing and all rich comparisons are now defined.
      * string literals with a ``+`` sign, empty literals and empty digit
        lists are handled explicitly.
    """

    def __init__(self, num=0):
        if isinstance(num, str):
            self.sign = 1
            if num[:1] in ("-", "+"):
                if num[0] == "-":
                    self.sign = -1
                num = num[1:]
            if not num or any(ch not in "0123456789" for ch in num):
                raise ValueError(f"invalid BigInt literal: {num!r}")
            self.digits = [int(ch) for ch in reversed(num)]
        elif isinstance(num, int):
            self.sign = 1 if num >= 0 else -1
            num = abs(num)
            self.digits = [num] if num == 0 else []
            while num > 0:
                self.digits.append(num % 10)
                num //= 10
        else:
            # Any iterable of little-endian digits; empty means zero.
            self.sign = 1
            self.digits = list(num) or [0]

        self._normalize()

    @classmethod
    def _from_parts(cls, sign, digits):
        """Build a normalised instance directly from a sign and digit list."""
        result = cls.__new__(cls)
        result.sign = sign
        result.digits = digits
        result._normalize()
        return result

    @staticmethod
    def _coerce(value):
        """Return ``value`` as a BigInt, or None if it is neither BigInt nor int."""
        if isinstance(value, BigInt):
            return value
        if isinstance(value, int):
            return BigInt(value)
        return None

    def _normalize(self):
        """Strip high-order zeros and give zero a positive sign."""
        while len(self.digits) > 1 and self.digits[-1] == 0:
            self.digits.pop()
        if len(self.digits) == 1 and self.digits[0] == 0:
            self.sign = 1

    def __abs__(self):
        return self._from_parts(1, self.digits[:])

    def __neg__(self):
        return self._from_parts(-self.sign, self.digits[:])

    def __add__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        if self.sign != other.sign:
            # Mixed signs reduce to a subtraction of magnitudes.
            if self.sign == -1:
                return other - (-self)
            return self - (-other)

        digits = []
        carry = 0
        for i in range(max(len(self.digits), len(other.digits))):
            a = self.digits[i] if i < len(self.digits) else 0
            b = other.digits[i] if i < len(other.digits) else 0
            carry, digit = divmod(a + b + carry, 10)
            digits.append(digit)
        if carry:
            digits.append(carry)

        return self._from_parts(self.sign, digits)

    __radd__ = __add__

    def __sub__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        if self.sign != other.sign:
            return self + (-other)

        # Same sign: subtract the smaller magnitude from the larger one.
        if abs(self) < abs(other):
            return -(other - self)

        digits = []
        borrow = 0
        for i, a in enumerate(self.digits):
            b = other.digits[i] if i < len(other.digits) else 0
            diff = a - b - borrow
            borrow = 1 if diff < 0 else 0
            digits.append(diff + 10 * borrow)

        return self._from_parts(self.sign, digits)

    def __rsub__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return other - self

    def __mul__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        # Schoolbook multiplication; position i + j collects digit products.
        digits = [0] * (len(self.digits) + len(other.digits))
        for i, a in enumerate(self.digits):
            carry = 0
            for j, b in enumerate(other.digits):
                carry, digits[i + j] = divmod(digits[i + j] + a * b + carry, 10)
            digits[i + len(other.digits)] += carry

        return self._from_parts(self.sign * other.sign, digits)

    __rmul__ = __mul__

    def __lt__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        if self.sign != other.sign:
            return self.sign < other.sign

        # More digits means a larger magnitude; the order flips for negatives.
        if len(self.digits) != len(other.digits):
            return (len(self.digits) < len(other.digits)) == (self.sign == 1)

        for a, b in zip(reversed(self.digits), reversed(other.digits)):
            if a != b:
                return (a < b) == (self.sign == 1)

        return False

    def __eq__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return self.sign == other.sign and self.digits == other.digits

    def __le__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return not other < self

    def __gt__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return other < self

    def __ge__(self, other):
        other = self._coerce(other)
        if other is None:
            return NotImplemented
        return not self < other

    def __hash__(self):
        # Hash the integer value so that BigInt(k) == k implies equal hashes.
        value = 0
        for d in reversed(self.digits):
            value = value * 10 + d
        return hash(self.sign * value)

    def __str__(self):
        s = ''.join(map(str, reversed(self.digits)))
        return '-' + s if self.sign == -1 and s != '0' else s

    def __repr__(self):
        return f"BigInt('{self}')"


a = BigInt("12345678901234567890")
b = BigInt("98765432109876543210")
print(f"加法: {a + b}")
print(f"乘法: {a * b}")

高精度除法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def big_int_divide(dividend, divisor):
    """Long division of non-negative digit lists.

    All numbers are little-endian digit lists (``[2, 1]`` is 12), the same
    convention used by ``compare``, ``subtract`` and ``remove_leading_zeros``.

    The original iterated the dividend from its least-significant digit and
    appended each digit as a new *most*-significant remainder digit, which
    produced wrong quotients (e.g. 12 / 5 gave 20).  Digits are now consumed
    from the most-significant end and shifted into the remainder.

    Returns:
        ``(quotient, remainder)`` as little-endian digit lists.  The caller's
        lists are not modified.

    Raises:
        ValueError: if ``divisor`` is zero.
    """
    dividend = remove_leading_zeros(list(dividend)) or [0]
    divisor = remove_leading_zeros(list(divisor)) or [0]
    if len(divisor) == 1 and divisor[0] == 0:
        raise ValueError("Division by zero")

    if compare(dividend, divisor) < 0:
        return [0], dividend[:]

    quotient_msd_first = []
    remainder = []

    for digit in reversed(dividend):
        # remainder = remainder * 10 + digit  (prepend in little-endian form)
        remainder.insert(0, digit)
        remove_leading_zeros(remainder)

        q = 0
        while compare(remainder, divisor) >= 0:
            remainder = subtract(remainder, divisor)
            q += 1

        quotient_msd_first.append(q)

    quotient = quotient_msd_first[::-1]
    return remove_leading_zeros(quotient), remove_leading_zeros(remainder)


def compare(a, b):
    """Compare little-endian digit lists without leading zeros.

    Returns a negative number, 0 or a positive number as a < b, a == b, a > b.
    """
    if len(a) != len(b):
        return len(a) - len(b)
    for i in range(len(a) - 1, -1, -1):
        if a[i] != b[i]:
            return a[i] - b[i]
    return 0


def subtract(a, b):
    """Return a - b for little-endian digit lists; assumes a >= b."""
    result = []
    borrow = 0

    for i in range(len(a)):
        ai = a[i]
        bi = b[i] if i < len(b) else 0
        diff = ai - bi - borrow

        if diff < 0:
            diff += 10
            borrow = 1
        else:
            borrow = 0

        result.append(diff)

    return remove_leading_zeros(result)


def remove_leading_zeros(arr):
    """Pop high-order zeros in place (keeping one digit) and return ``arr``."""
    while len(arr) > 1 and arr[-1] == 0:
        arr.pop()
    return arr

六十四、概率与随机算法

概率计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import random

def monte_carlo_pi(n):
    """Estimate pi from ``n`` uniform points in the unit square.

    The fraction of points with x^2 + y^2 <= 1 approximates pi / 4.
    """
    hits = 0
    for _ in range(n):
        px = random.random()
        py = random.random()
        hits += px * px + py * py <= 1
    return 4 * hits / n

def birthday_paradox(n, trials=10000):
    """Simulated probability that ``n`` people include a shared birthday.

    Each trial draws ``n`` birthdays uniformly from 1..365; the result is the
    fraction of trials containing at least one duplicate.
    """
    collisions = 0
    for _ in range(trials):
        days = [random.randint(1, 365) for _ in range(n)]
        collisions += len(set(days)) < len(days)
    return collisions / trials

def expected_value(values, probabilities):
    """Return E[X] = sum(value * probability) over the paired sequences."""
    return sum(value * prob for value, prob in zip(values, probabilities))


def variance(values, probabilities):
    """Return Var[X] = E[(X - E[X])^2] for a discrete distribution."""
    mu = expected_value(values, probabilities)
    squared_deviations = [(value - mu) ** 2 for value in values]
    return expected_value(squared_deviations, probabilities)

# Demo: both values are random estimates; no seed is set here, so they vary per run.
print(f"蒙特卡洛 π: {monte_carlo_pi(100000)}")
print(f"生日悖论 (23人): {birthday_paradox(23)}")

随机化算法应用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def quick_select_randomized(arr, k):
    """Return the k-th smallest element (1-based) of ``arr`` in expected O(n).

    Randomised quickselect with a three-way split around a random pivot,
    written as a loop so no recursion depth is consumed.  ``arr`` is not
    modified.

    Raises:
        ValueError: unless 1 <= k <= len(arr).  Previously k <= 0 silently
            returned the minimum and k > len(arr) failed with an IndexError
            from ``random.choice([])``.
    """
    if not 1 <= k <= len(arr):
        raise ValueError(f"k must be between 1 and {len(arr)}, got {k}")

    candidates = arr
    while len(candidates) > 1:
        pivot = random.choice(candidates)
        left = [x for x in candidates if x < pivot]
        mid = [x for x in candidates if x == pivot]

        if k <= len(left):
            candidates = left
        elif k <= len(left) + len(mid):
            return pivot
        else:
            k -= len(left) + len(mid)
            candidates = [x for x in candidates if x > pivot]

    return candidates[0]

def randomized_partition(arr, low, high):
    """Lomuto partition of arr[low..high] around a uniformly random pivot.

    Returns the pivot's final index p, with arr[low..p-1] <= arr[p] and
    arr[p+1..high] > arr[p].
    """
    pivot_idx = random.randint(low, high)
    arr[pivot_idx], arr[high] = arr[high], arr[pivot_idx]

    pivot = arr[high]
    i = low - 1

    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]

    arr[i + 1], arr[high] = arr[high], arr[i + 1]
    return i + 1


def randomized_quicksort(arr, low, high):
    """Sort arr[low..high] (inclusive) in place with randomised quicksort.

    Uses a three-way (Dutch national flag) partition, so runs of equal keys
    are finished in one pass.  Only the smaller side is recursed into and
    the larger side is handled by the loop, which bounds the stack at
    O(log n).  The previous version used ``randomized_partition`` and
    recursed on both sides: an all-equal array put every element on one side,
    giving O(n^2) time and a RecursionError for n above roughly 1000.
    """
    while low < high:
        pivot = arr[random.randint(low, high)]
        lt, i, gt = low, low, high
        # Invariant: arr[low:lt] < pivot, arr[lt:i] == pivot, arr[gt+1:high+1] > pivot
        while i <= gt:
            x = arr[i]
            if x < pivot:
                arr[lt], arr[i] = x, arr[lt]
                lt += 1
                i += 1
            elif x > pivot:
                arr[i], arr[gt] = arr[gt], x
                gt -= 1
            else:
                i += 1

        if lt - low < high - gt:
            randomized_quicksort(arr, low, lt - 1)
            low = gt + 1
        else:
            randomized_quicksort(arr, gt + 1, high)
            high = lt - 1

def shuffle_array(arr):
    """Shuffle ``arr`` in place with Fisher-Yates and return it.

    Walking from the last index down to 1, each position is swapped with a
    uniformly chosen index in [0, i].
    """
    i = len(arr) - 1
    while i > 0:
        j = random.randint(0, i)
        arr[i], arr[j] = arr[j], arr[i]
        i -= 1
    return arr

# Demo: sorted this list is [1, 1, 2, 3, 4, 5, 6, 9], so the 4th smallest is 3.
arr = [3, 1, 4, 1, 5, 9, 2, 6]
print(f"第 4 小元素: {quick_select_randomized(arr, 4)}")

蓄水池抽样

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
class ReservoirSampler:
    """Incremental reservoir sampling (Algorithm R) keeping at most ``k`` items."""

    def __init__(self, k):
        self.k = k
        self.samples = []
        self.count = 0  # items seen so far

    def add(self, item):
        """Offer one stream item to the reservoir."""
        self.count += 1
        if len(self.samples) >= self.k:
            # Keep the new item with probability k / count, replacing a random slot.
            slot = random.randint(0, self.count - 1)
            if slot < self.k:
                self.samples[slot] = item
            return
        self.samples.append(item)

    def get_samples(self):
        """Return the current sample list (the live internal list)."""
        return self.samples

def reservoir_sample(stream, k):
    """Return a uniform random sample of ``k`` items from an iterable (Algorithm R).

    The first ``k`` items fill the reservoir; item number ``seen`` (0-based)
    afterwards replaces a random slot with probability k / (seen + 1).
    """
    reservoir = []
    for seen, element in enumerate(stream):
        if seen < k:
            reservoir.append(element)
            continue
        slot = random.randint(0, seen)
        if slot < k:
            reservoir[slot] = element
    return reservoir

六十五、综合练习题

练习 1:区间最值查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
class SparseTable:
    """O(1) static range queries for an idempotent operation (min by default).

    Building takes O(n log n).  ``op`` must be associative and idempotent
    (op(x, x) == x), e.g. ``min``, ``max`` or ``math.gcd``, because a query
    combines two power-of-two blocks that may overlap.  The ``op`` parameter
    generalises the original min-only table; the default keeps old behaviour.
    """

    def __init__(self, arr, op=min):
        self.op = op
        self.n = len(arr)
        # log[i] = floor(log2(i)) for 1 <= i <= n
        self.log = [0] * (self.n + 1)
        for i in range(2, self.n + 1):
            self.log[i] = self.log[i // 2] + 1

        self.k = self.log[self.n] + 1
        # st[j][i] = op over arr[i : i + 2**j]
        self.st = [list(arr)]
        for j in range(1, self.k):
            prev = self.st[j - 1]
            half = 1 << (j - 1)
            self.st.append([op(prev[i], prev[i + half])
                            for i in range(self.n - (1 << j) + 1)])

    def query(self, l, r):
        """Return ``op`` over arr[l..r] (0-based, inclusive).

        Raises:
            IndexError: unless 0 <= l <= r < n (an empty range used to return
                an unrelated value silently).
        """
        if not 0 <= l <= r < self.n:
            raise IndexError(f"invalid range [{l}, {r}] for length {self.n}")
        j = self.log[r - l + 1]
        return self.op(self.st[j][l], self.st[j][r - (1 << j) + 1])


arr = [1, 3, 2, 7, 9, 11, 3, 5, 6]
st = SparseTable(arr)
print(f"区间 [2, 5] 最小值: {st.query(2, 5)}")

练习 2:最近公共祖先

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class LCA:
    """Lowest common ancestor by binary lifting on an undirected tree.

    Args:
        adj: adjacency lists of the tree (nodes 0..n-1).
        root: root node.

    ``parent[k][v]`` is the 2**k-th ancestor of v (-1 if none).  The number
    of levels is derived from n instead of the fixed 20, and the initial
    traversal is iterative, so deep trees (long paths) no longer exceed
    Python's recursion limit.
    """

    def __init__(self, adj, root=0):
        self.n = len(adj)
        self.adj = adj
        self.root = root
        # Any depth is < n <= 2**levels, so this many jump levels suffice.
        self.levels = max(1, self.n.bit_length())

        self.depth = [0] * self.n
        self.parent = [[-1] * self.n for _ in range(self.levels)]

        self._dfs(root, -1, 0)

        for k in range(1, self.levels):
            prev = self.parent[k - 1]
            cur = self.parent[k]
            for v in range(self.n):
                if prev[v] != -1:
                    cur[v] = prev[prev[v]]

    def _dfs(self, v, p, d):
        """Record direct parent and depth for every node reachable from ``v``."""
        stack = [(v, p, d)]
        while stack:
            node, par, dep = stack.pop()
            self.parent[0][node] = par
            self.depth[node] = dep
            for nxt in self.adj[node]:
                if nxt != par:
                    stack.append((nxt, node, dep + 1))

    def get_lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        if self.depth[u] < self.depth[v]:
            u, v = v, u

        # Lift the deeper node to the same depth.
        diff = self.depth[u] - self.depth[v]
        for k in range(self.levels):
            if diff & (1 << k):
                u = self.parent[k][u]

        if u == v:
            return u

        # Jump both nodes while their ancestors differ.
        for k in range(self.levels - 1, -1, -1):
            if self.parent[k][u] != self.parent[k][v]:
                u = self.parent[k][u]
                v = self.parent[k][v]

        return self.parent[0][u]

    def get_distance(self, u, v):
        """Number of edges on the path between ``u`` and ``v``."""
        lca = self.get_lca(u, v)
        return self.depth[u] + self.depth[v] - 2 * self.depth[lca]


adj = [[1, 2], [0, 3, 4], [0, 5], [1], [1], [2]]
lca = LCA(adj)
print(f"LCA(3, 5): {lca.get_lca(3, 5)}")
print(f"距离(3, 5): {lca.get_distance(3, 5)}")

练习 3:滑动窗口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from collections import deque

def _sliding_window_extreme(arr, k, should_evict):
    """Shared monotonic-deque scan for sliding-window max/min.

    ``should_evict(back_value, new_value)`` is True when the value at the back
    of the deque can never again be a window answer once ``new_value`` arrives.

    Raises:
        ValueError: if ``k < 1`` (previously produced meaningless output).
    """
    if k < 1:
        raise ValueError(f"window size k must be >= 1, got {k}")

    result = []
    dq = deque()  # indices whose values are monotone from front to back

    for i, value in enumerate(arr):
        # Drop the index that slid out of the window [i - k + 1, i].
        while dq and dq[0] < i - k + 1:
            dq.popleft()

        while dq and should_evict(arr[dq[-1]], value):
            dq.pop()

        dq.append(i)

        if i >= k - 1:
            result.append(arr[dq[0]])

    return result


def sliding_window_max(arr, k):
    """Maximum of every length-``k`` window of ``arr`` (O(n))."""
    return _sliding_window_extreme(arr, k, lambda back, new: back < new)


def sliding_window_min(arr, k):
    """Minimum of every length-``k`` window of ``arr`` (O(n))."""
    return _sliding_window_extreme(arr, k, lambda back, new: back > new)


arr = [1, 3, -1, -3, 5, 3, 6, 7]
k = 3
print(f"滑动窗口最大值: {sliding_window_max(arr, k)}")
print(f"滑动窗口最小值: {sliding_window_min(arr, k)}")

练习 4:单调栈

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def next_greater_element(arr):
    """For each position, the first strictly greater value to its right (-1 if none).

    Scans right-to-left keeping a stack of candidate values; anything not
    larger than the current value can never be an answer for positions
    further left.
    """
    answer = [-1] * len(arr)
    candidates = []
    for idx in range(len(arr) - 1, -1, -1):
        while candidates and candidates[-1] <= arr[idx]:
            candidates.pop()
        if candidates:
            answer[idx] = candidates[-1]
        candidates.append(arr[idx])
    return answer

def previous_smaller_element(arr):
    """For each position, the nearest strictly smaller value to its left (-1 if none)."""
    out = []
    rising = []  # strictly increasing values seen so far (monotonic stack)
    for value in arr:
        while rising and rising[-1] >= value:
            rising.pop()
        out.append(rising[-1] if rising else -1)
        rising.append(value)
    return out

def largest_rectangle_in_histogram(heights):
    """Area of the largest axis-aligned rectangle under a histogram (O(n)).

    A monotonic stack holds indices of increasing bar heights; when a lower
    bar arrives, each popped bar's maximal width is known.  A 0-height
    sentinel flushes the stack at the end.

    The sentinel is appended to a private copy.  The original appended to and
    popped from the caller's list, which failed for tuples and left the list
    modified if an exception escaped mid-scan.
    """
    bars = list(heights)
    bars.append(0)
    stack = []
    max_area = 0

    for i, h in enumerate(bars):
        while stack and bars[stack[-1]] > h:
            height = bars[stack.pop()]
            width = i - stack[-1] - 1 if stack else i
            max_area = max(max_area, height * width)
        stack.append(i)

    return max_area

# Demo: next greater elements of ``arr`` and the classic histogram example (largest area 10).
arr = [2, 1, 2, 4, 3]
print(f"下一个更大元素: {next_greater_element(arr)}")
print(f"柱状图最大矩形: {largest_rectangle_in_histogram([2, 1, 5, 6, 2, 3])}")

练习 5:拓扑排序应用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def topological_sort_kahn(n, edges):
    """Topological order of nodes 0..n-1 by Kahn's algorithm.

    ``edges`` are (u, v) pairs meaning u must come before v.  Ready nodes are
    processed FIFO, starting in index order.  Returns the order, or None if
    the graph has a cycle (not every node could be emitted).
    """
    from collections import deque

    successors = [[] for _ in range(n)]
    indegree = [0] * n
    for src, dst in edges:
        successors[src].append(dst)
        indegree[dst] += 1

    ready = deque(node for node in range(n) if indegree[node] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for nxt in successors[node]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)

    if len(order) != n:
        return None
    return order


def course_schedule(num_courses, prerequisites):
    """True if all courses can be finished given [course, prerequisite] pairs."""
    return find_order(num_courses, prerequisites) is not None


def find_order(num_courses, prerequisites):
    """A valid course order for [course, prerequisite] pairs, or None on a cycle."""
    edges = [(before, course) for course, before in prerequisites]
    return topological_sort_kahn(num_courses, edges)


prerequisites = [[1, 0], [2, 0], [3, 1], [3, 2]]
print(f"能否完成课程: {course_schedule(4, prerequisites)}")
print(f"课程顺序: {find_order(4, prerequisites)}")

六十六、动态规划专题进阶

区间 DP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def matrix_chain_multiplication(dims):
    """Minimum scalar multiplications to compute a matrix-chain product.

    Matrix i has shape dims[i] x dims[i + 1].  dp[i][j] is the best cost of
    multiplying matrices i..j; splitting after k adds the cost of the final
    product dims[i] * dims[k + 1] * dims[j + 1].

    Returns 0 when there are no matrices (fewer than two dims), where the
    original raised IndexError.
    """
    n = len(dims) - 1
    if n < 1:
        return 0
    dp = [[0] * n for _ in range(n)]

    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')

            for k in range(i, j):
                cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                dp[i][j] = min(dp[i][j], cost)

    return dp[0][n - 1]

def optimal_bst(keys, freq):
    """Minimum weighted search cost of a BST over sorted ``keys``.

    Cost is sum(freq[i] * depth(i)) with the root at depth 1.  ``keys`` only
    determines the count; the DP reads ``freq``.  dp[i][j] is the optimal
    cost for keys i..j; choosing root k adds the total frequency of i..j
    once, because every key in the range moves one level deeper.

    Returns 0 for no keys (the original raised IndexError on dp[0][-1]).
    """
    n = len(keys)
    if n == 0:
        return 0
    dp = [[0] * n for _ in range(n)]
    prefix_sum = [0] * (n + 1)

    for i in range(n):
        prefix_sum[i + 1] = prefix_sum[i] + freq[i]

    for length in range(1, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')

            for k in range(i, j + 1):
                left = dp[i][k - 1] if k > i else 0
                right = dp[k + 1][j] if k < j else 0
                cost = left + right + prefix_sum[j + 1] - prefix_sum[i]
                dp[i][j] = min(dp[i][j], cost)

    return dp[0][n - 1]

def palindrome_partitioning(s):
    """Minimum number of cuts so that every piece of ``s`` is a palindrome.

    is_palindrome[i][j] is filled by increasing length; dp[i] is the minimum
    number of cuts for s[0..i].

    Returns 0 for the empty string (the original raised IndexError on dp[-1]).
    """
    n = len(s)
    if n == 0:
        return 0
    is_palindrome = [[False] * n for _ in range(n)]

    for i in range(n):
        is_palindrome[i][i] = True

    for i in range(n - 1):
        is_palindrome[i][i + 1] = s[i] == s[i + 1]

    for length in range(3, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            is_palindrome[i][j] = (s[i] == s[j]) and is_palindrome[i + 1][j - 1]

    dp = [float('inf')] * n

    for i in range(n):
        if is_palindrome[0][i]:
            dp[i] = 0
        else:
            # Last piece is s[j+1..i]; one cut separates it from s[0..j].
            for j in range(i):
                if is_palindrome[j + 1][i]:
                    dp[i] = min(dp[i], dp[j] + 1)

    return dp[n - 1]

def stone_merge(stones):
    """Minimum total cost to merge adjacent piles into one (石子合并).

    Merging two adjacent piles costs their combined size.  dp[i][j] is the
    best cost for piles i..j; the last merge always costs sum(stones[i..j]),
    read from prefix sums.

    Returns 0 for no piles (the original raised IndexError on dp[0][-1]).
    """
    n = len(stones)
    if n == 0:
        return 0
    prefix = [0] * (n + 1)

    for i in range(n):
        prefix[i + 1] = prefix[i] + stones[i]

    dp = [[0] * n for _ in range(n)]

    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')

            for k in range(i, j):
                cost = dp[i][k] + dp[k + 1][j] + prefix[j + 1] - prefix[i]
                dp[i][j] = min(dp[i][j], cost)

    return dp[0][n - 1]

# Demo: expected 4500 (= 10*30*5 + 10*5*60) and 22 for piles [1, 3, 5, 2].
dims = [10, 30, 5, 60]
print(f"矩阵链乘最小代价: {matrix_chain_multiplication(dims)}")
print(f"石子合并最小代价: {stone_merge([1, 3, 5, 2])}")

树形 DP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def tree_dp_max_independent_set(adj, n, root=0):
    """Size of a maximum independent set of the tree given by ``adj``.

    dp[u][0]: best count within u's subtree when u is NOT chosen.
    dp[u][1]: best count within u's subtree when u IS chosen.

    Nodes are collected in pre-order with an explicit stack and finalised in
    reverse, so every child is complete before its parent.  This replaces the
    recursive DFS, which raised RecursionError on deep (path-like) trees.
    """
    dp = [[0, 1] for _ in range(n)]
    parent = [-1] * n
    order = []
    stack = [(root, -1)]
    while stack:
        u, p = stack.pop()
        parent[u] = p
        order.append(u)
        for v in adj[u]:
            if v != p:
                stack.append((v, u))

    for u in reversed(order):
        p = parent[u]
        if p != -1:
            dp[p][0] += max(dp[u][0], dp[u][1])
            dp[p][1] += dp[u][0]

    return max(dp[root][0], dp[root][1])

def tree_dp_diameter(adj, n):
    """Diameter (edge count of the longest path) of the tree ``adj``, rooted at 0.

    For every node the two deepest child branches are combined.  The
    traversal is iterative (pre-order collected, heights finalised in
    reverse), avoiding the RecursionError the recursive DFS hit on deep
    trees.  ``n`` is accepted for interface compatibility; like the original,
    the function does not use it.

    Returns 0 for an empty tree (the original raised IndexError on adj[0]).
    """
    if not adj:
        return 0
    size = len(adj)
    parent = [-1] * size
    order = []
    stack = [(0, -1)]
    while stack:
        u, p = stack.pop()
        parent[u] = p
        order.append(u)
        for v in adj[u]:
            if v != p:
                stack.append((v, u))

    height = [0] * size  # longest downward path (in edges) from each node
    diameter = 0
    for u in reversed(order):
        max1 = max2 = 0
        for v in adj[u]:
            if v != parent[u]:
                depth = height[v] + 1
                if depth > max1:
                    max1, max2 = depth, max1
                elif depth > max2:
                    max2 = depth
        diameter = max(diameter, max1 + max2)
        height[u] = max1

    return diameter

def tree_dp_tree_backpack(adj, values, weights, capacity, n, root=0):
    """Knapsack DP over a tree: best total value within ``capacity``.

    best[u][c] is the best value obtainable from u's subtree with total
    weight <= c.  Children are merged as a group knapsack (splitting the
    capacity between what is already in best[u] and the child), then u's own
    item is offered as an optional 0/1 item.

    NOTE(review): because u's item is optional and offered after its children
    are merged, a descendant can be chosen without its ancestors; the result
    equals an unconstrained 0/1 knapsack over all nodes reachable from
    ``root``.  If a dependent knapsack (有依赖的背包, a child requires its
    parent) is intended, the recurrence must force u in; confirm intent.
    """
    best = [[0] * (capacity + 1) for _ in range(n)]

    def solve(node, came_from):
        table = best[node]
        for child in adj[node]:
            if child == came_from:
                continue
            solve(child, node)
            child_table = best[child]
            # Descending capacities so lower entries still hold pre-merge values.
            for cap in reversed(range(capacity + 1)):
                table[cap] = max(table[cap], *(table[cap - used] + child_table[used]
                                               for used in range(cap + 1)))

        w, val = weights[node], values[node]
        for cap in range(capacity, w - 1, -1):
            table[cap] = max(table[cap], table[cap - w] + val)

    solve(root, -1)
    return best[root][capacity]

# Demo tree (undirected adjacency lists): edges 0-1, 0-2, 1-3, 1-4.
# Expected: maximum independent set {2, 3, 4} -> 3; diameter 3-1-0-2 -> 3.
adj = [[1, 2], [0, 3, 4], [0], [1], [1]]
print(f"最大独立集: {tree_dp_max_independent_set(adj, 5)}")
print(f"树直径: {tree_dp_diameter(adj, 5)}")

斜率优化 DP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
def slope_optimization_dp(n, cost_func):
    """Generic monotone-queue (convex hull trick) DP skeleton.

    Computes dp[1..n] with dp[0] = 0 and
        dp[i] = dp[j] + cost_func.calculate(j, i)
    where j is the front of a deque of candidate indices.  Candidates are
    compared by the slope (dp[a] - dp[b]) / (a - b) between points
    (index, dp value).

    ``cost_func`` is duck-typed; from the calls below it must provide
    ``get_slope(i)`` (the query slope for step i) and ``calculate(j, i)``.
    NOTE(review): the hull points use x = index and y = dp only, while most
    slope-optimisation problems need problem-specific x/y terms; confirm the
    template matches the intended recurrence before use.

    Returns dp[n].
    """
    from collections import deque

    dp = [0] * (n + 1)
    q = deque([0])

    def get_slope(i, j):
        # Slope between candidate points (i, dp[i]) and (j, dp[j]).
        return (dp[i] - dp[j]) / (i - j) if i != j else float('inf')

    for i in range(1, n + 1):
        # Discard front candidates while the slope to the next candidate does
        # not exceed the query slope for step i.
        while len(q) >= 2 and get_slope(q[0], q[1]) <= cost_func.get_slope(i):
            q.popleft()

        j = q[0]
        dp[i] = dp[j] + cost_func.calculate(j, i)

        # Before appending i, pop the tail until slopes between consecutive
        # queue entries are strictly increasing.
        while len(q) >= 2 and get_slope(q[-2], q[-1]) >= get_slope(q[-1], i):
            q.pop()

        q.append(i)

    return dp[n]

def print_workshop(n, a, b, c):
    """Convex-hull-trick DP over prefix sums of ``a``.

    With S = prefix sums of ``a`` the recurrence evaluated below is
        dp[i] = dp[j] + b*(S[i] - S[j])**2 + c*(S[i] - S[j]) + a[i-1]
    where j is taken from a monotone queue of points
        (x_j, y_j) = (S[j], dp[j] + b*S[j]**2 - c*S[j]).
    Returns dp[n].

    NOTE(review): the tail pop (remove q[-1] while slope(q[-2], q[-1]) <=
    slope(q[-1], i)) keeps slopes decreasing along the queue, and the front
    pop advances while the slope is >= 2*b*S[i].  Assuming S is non-decreasing
    (a >= 0) and b > 0, this selects the j that MAXIMISES y_j - 2*b*S[i]*x_j;
    a cost-minimising DP would normally need both comparisons reversed.  The
    extra "+ a[i-1]" term is also unusual.  Confirm the intended problem.
    NOTE(review): get_slope divides by S[j] - S[k], so equal prefix sums
    (a zero in ``a``) raise ZeroDivisionError.
    """
    prefix = [0] * (n + 1)
    for i in range(1, n + 1):
        prefix[i] = prefix[i - 1] + a[i - 1]

    dp = [0] * (n + 1)
    q = [0]  # candidate indices; q[head:] is the live part of the queue

    def get_x(i):
        return prefix[i]

    def get_y(i):
        return dp[i] + b * prefix[i] * prefix[i] - c * prefix[i]

    def get_slope(j, k):
        return (get_y(j) - get_y(k)) / (get_x(j) - get_x(k))

    head = 0
    for i in range(1, n + 1):
        while head < len(q) - 1 and get_slope(q[head], q[head + 1]) >= 2 * b * prefix[i]:
            head += 1

        j = q[head]
        dp[i] = dp[j] + b * (prefix[i] - prefix[j]) ** 2 + c * (prefix[i] - prefix[j]) + a[i - 1]

        # Never pop into the already-discarded prefix q[:head].
        while len(q) > head + 1 and get_slope(q[-2], q[-1]) <= get_slope(q[-1], i):
            q.pop()

        q.append(i)

    return dp[n]

六十七、图论应用专题

最小生成树变种

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def kruskal_with_restriction(n, edges, must_include, exclude):
    """Minimum spanning tree weight with forced and forbidden edges.

    Edges in ``must_include`` are taken first (any that would close a cycle
    are skipped); edges equal to an entry of ``exclude`` or ``must_include``
    are removed from ``edges``, and the rest are added greedily by weight
    (Kruskal, union by rank).  Edges are ``(u, v, w)`` triples.

    Returns the total weight, or -1 if the chosen edges cannot connect all
    ``n`` vertices.
    """
    leader = list(range(n))
    rank = [0] * n

    def find(x):
        root = x
        while leader[root] != root:
            root = leader[root]
        while leader[x] != root:  # path compression
            leader[x], x = root, leader[x]
        return root

    def union(x, y):
        rx, ry = find(x), find(y)
        if rx == ry:
            return False
        if rank[rx] < rank[ry]:
            rx, ry = ry, rx
        leader[ry] = rx
        if rank[rx] == rank[ry]:
            rank[rx] += 1
        return True

    total = 0
    used = 0
    for u, v, w in must_include:
        if union(u, v):
            total += w
            used += 1

    candidates = sorted(
        (e for e in edges if e not in exclude and e not in must_include),
        key=lambda e: e[2],
    )
    for u, v, w in candidates:
        if not union(u, v):
            continue
        total += w
        used += 1
        if used == n - 1:
            break

    return total if used == n - 1 else -1

def second_mst(n, edges):
    """Return (MST weight, weight of the best strictly heavier spanning tree).

    The MST is found with Kruskal; the candidate second-best tree is obtained
    by deleting each MST edge in turn and recomputing the MST.  Both values
    are ``inf`` when the graph is disconnected; the second value is ``inf``
    when no strictly heavier tree is found.

    Fixes: the inner ``return total, mst_edges if ... else (inf, [])`` parsed
    as ``(total, (mst_edges if ... else (inf, [])))``, so a disconnected graph
    reported a finite MST.  Sorting now happens on a copy instead of
    reordering the caller's ``edges`` list.

    NOTE(review): with tied weights, deleting one MST edge may only find
    another tree of equal weight, so the strict second-best can be missed
    (e.g. a 4-cycle of weight-1 edges plus a heavier chord reports inf).
    """
    def kruskal(n, edges, skip=None):
        parent = list(range(n))

        def find(x):
            root = x
            while parent[root] != root:
                root = parent[root]
            while parent[x] != root:  # path compression, iterative
                parent[x], x = root, parent[x]
            return root

        total = 0
        mst_edges = []

        for i, (u, v, w) in enumerate(edges):
            if skip == i:
                continue
            ru, rv = find(u), find(v)
            if ru != rv:
                parent[ru] = rv
                total += w
                mst_edges.append(i)

        if len(mst_edges) != n - 1:
            return float('inf'), []
        return total, mst_edges

    ordered = sorted(edges, key=lambda x: x[2])
    mst_cost, mst_edges = kruskal(n, ordered)

    second_best = float('inf')
    for skip_idx in mst_edges:
        cost, _ = kruskal(n, ordered, skip_idx)
        if cost > mst_cost:
            second_best = min(second_best, cost)

    return mst_cost, second_best

def minimum_spanning_arborescence(n, edges, root):
    """Minimum-weight arborescence rooted at ``root`` (Chu-Liu/Edmonds).

    Args:
        n: number of vertices 0..n-1.
        edges: directed ``(u, v, w)`` edges u -> v.
        root: the root vertex.

    Returns:
        The total weight, or ``inf`` if some vertex cannot be reached from
        ``root``.

    Each non-root vertex picks its cheapest incoming edge; cycles among those
    picks are contracted and the algorithm recurses on the contracted graph
    with weights reduced by the chosen in-edge weight.

    Fixes: a vertex with no incoming edge used to keep ``parent == -1``, and
    the cycle walk then indexed ``visited[-1]`` (the last vertex), producing
    garbage.  Self-loops, which can never be in an arborescence, are ignored.
    """
    INF = float('inf')
    in_edge = [INF] * n
    parent = [-1] * n

    for u, v, w in edges:
        if u != v and v != root and w < in_edge[v]:
            in_edge[v] = w
            parent[v] = u

    for v in range(n):
        if v != root and in_edge[v] == INF:
            return INF  # v has no usable incoming edge

    visited = [-1] * n
    cycle_id = [-1] * n
    total = 0
    cnt = 0

    for i in range(n):
        if i == root:
            continue

        total += in_edge[i]

        # Follow chosen in-edges; revisiting a vertex from this same walk means a cycle.
        v = i
        while v != root and visited[v] == -1:
            visited[v] = i
            v = parent[v]

        if v != root and visited[v] == i:
            u = parent[v]
            while u != v:
                cycle_id[u] = cnt
                u = parent[u]
            cycle_id[v] = cnt
            cnt += 1

    if cnt == 0:
        return total

    # Every vertex outside a cycle becomes its own contracted vertex.
    for i in range(n):
        if cycle_id[i] == -1:
            cycle_id[i] = cnt
            cnt += 1

    new_edges = []
    for u, v, w in edges:
        new_u = cycle_id[u]
        new_v = cycle_id[v]
        # Edges into the root are never used; edges inside a component vanish.
        if new_u != new_v and v != root:
            new_edges.append((new_u, new_v, w - in_edge[v]))

    return total + minimum_spanning_arborescence(cnt, new_edges, cycle_id[root])

差分与前缀和进阶

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
def difference_array_2d(matrix, updates):
    """Apply rectangle additions to ``matrix`` via a 2-D difference array.

    Each update ``(x1, y1, x2, y2, val)`` adds ``val`` to every cell with
    x1 <= row <= x2 and y1 <= col <= y2 (0-based, inclusive).  The four corner
    marks are turned into per-cell deltas by an in-place 2-D prefix sum.
    Returns a new matrix; ``matrix`` itself is not modified.
    """
    rows, cols = len(matrix), len(matrix[0])
    delta = [[0] * (cols + 2) for _ in range(rows + 2)]

    for top, left, bottom, right, amount in updates:
        delta[top][left] += amount
        delta[top][right + 1] -= amount
        delta[bottom + 1][left] -= amount
        delta[bottom + 1][right + 1] += amount

    for r in range(rows):
        for c in range(cols):
            acc = delta[r][c]
            if r > 0:
                acc += delta[r - 1][c]
            if c > 0:
                acc += delta[r][c - 1]
            if r > 0 and c > 0:
                acc -= delta[r - 1][c - 1]
            delta[r][c] = acc

    return [[matrix[r][c] + delta[r][c] for c in range(cols)] for r in range(rows)]

def prefix_sum_2d(matrix):
    """Build a 2-D prefix-sum table and an O(1) rectangle-sum query.

    table[i][j] is the sum of matrix[0..i-1][0..j-1].  Returns
    ``(table, query)`` where ``query(x1, y1, x2, y2)`` sums the inclusive
    0-based rectangle [x1..x2] x [y1..y2].
    """
    rows, cols = len(matrix), len(matrix[0])
    table = [[0] * (cols + 1) for _ in range(rows + 1)]

    for r in range(rows):
        above, here = table[r], table[r + 1]
        for c in range(cols):
            here[c + 1] = above[c + 1] + here[c] - above[c] + matrix[r][c]

    def query(x1, y1, x2, y2):
        # Inclusion-exclusion over the four corner prefixes.
        return (table[x2 + 1][y2 + 1] - table[x1][y2 + 1] -
                table[x2 + 1][y1] + table[x1][y1])

    return table, query

def difference_on_tree(n, adj, queries, root=0):
    """Add values along tree paths using node-weighted tree differencing.

    Each query ``(u, v, val)`` adds ``val`` to every node on the u-v path.
    For each query: +val at u and v, -val at their LCA and at the LCA's
    parent.  Summing these marks over each subtree then gives every node's
    total.

    Returns:
        A list whose i-th entry is the total added to node i.

    Changes: both traversals are iterative (the recursive versions hit the
    recursion limit on deep trees), and the subtree-size / heavy-child
    arrays, which were computed but never read, are no longer built.
    """
    parent = [-1] * n
    depth = [0] * n
    order = []  # pre-order: every node appears after its parent

    stack = [(root, -1)]
    while stack:
        u, p = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != p:
                parent[v] = u
                depth[v] = depth[u] + 1
                stack.append((v, u))

    diff = [0] * n

    for u, v, val in queries:
        diff[u] += val
        diff[v] += val
        lca = find_lca(u, v, parent, depth)
        diff[lca] -= val
        if parent[lca] != -1:
            diff[parent[lca]] -= val

    # Reverse pre-order visits children before parents: accumulate subtree sums.
    for u in reversed(order):
        if parent[u] != -1:
            diff[parent[u]] += diff[u]

    return diff


def find_lca(u, v, parent, depth):
    """Naive LCA: repeatedly step the deeper node up (O(depth) per call)."""
    while u != v:
        if depth[u] > depth[v]:
            u = parent[u]
        else:
            v = parent[v]
    return u

拓扑排序应用:关键路径与 DAG 最长路

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def critical_path(n, adj, durations):
    """Critical-path analysis (CPM) of a task DAG.

    Args:
        n: number of tasks, numbered 0..n-1.
        adj: adj[u] lists the tasks that may start only after task u finishes.
        durations: durations[u] is the running time of task u.

    Returns:
        ``(total_time, critical_tasks, earliest, latest)``.  earliest[u] and
        latest[u] are the earliest and latest start times of u; critical
        tasks are those with zero slack.  An empty project gives
        ``(0, [], [], [])``.

    Raises:
        ValueError: if the dependency graph contains a cycle.  Previously the
            tasks on the cycle were silently left with partial times.
    """
    from collections import deque

    if n == 0:
        return 0, [], [], []

    in_degree = [0] * n
    for u in range(n):
        for v in adj[u]:
            in_degree[v] += 1

    queue = deque(i for i in range(n) if in_degree[i] == 0)
    order = []
    earliest = [0] * n

    # Forward pass in topological order: earliest start = max predecessor finish.
    while queue:
        u = queue.popleft()
        order.append(u)
        finish = earliest[u] + durations[u]
        for v in adj[u]:
            earliest[v] = max(earliest[v], finish)
            in_degree[v] -= 1
            if in_degree[v] == 0:
                queue.append(v)

    if len(order) != n:
        raise ValueError("task graph contains a cycle")

    total_time = max(earliest[i] + durations[i] for i in range(n))

    # Backward pass: latest start that does not delay any successor.
    latest = [total_time - durations[i] for i in range(n)]
    for u in reversed(order):
        for v in adj[u]:
            latest[u] = min(latest[u], latest[v] - durations[u])

    critical_tasks = [i for i in range(n) if earliest[i] == latest[i]]
    return total_time, critical_tasks, earliest, latest

def longest_path_dag(n, adj, weights):
    """Longest distance to every node of a DAG, starting from all sources.

    Sources (in-degree 0) start at 0 and every other node at -inf; distances
    are relaxed in Kahn topological order.  ``weights`` maps ``(u, v)`` to
    the weight of edge u -> v.  Nodes never reached (e.g. on a cycle) stay
    at -inf.
    """
    from collections import deque

    indeg = [0] * n
    for node in range(n):
        for nxt in adj[node]:
            indeg[nxt] += 1

    best = [0 if indeg[node] == 0 else float('-inf') for node in range(n)]
    ready = deque(node for node in range(n) if indeg[node] == 0)

    while ready:
        node = ready.popleft()
        for nxt in adj[node]:
            candidate = best[node] + weights[(node, nxt)]
            if candidate > best[nxt]:
                best[nxt] = candidate
            indeg[nxt] -= 1
            if indeg[nxt] == 0:
                ready.append(nxt)

    return best

六十八、字符串处理专题

字符串匹配算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def boyer_moore(text, pattern):
    """Return the start indices of all occurrences of ``pattern`` in ``text``.

    Boyer-Moore with the bad-character rule and the good-suffix table built
    as in Charras & Lecroq's ``preBmGs``.  Returns [] for an empty pattern.

    Fix: on a mismatch at pattern position ``j`` the good-suffix shift is
    ``good_suffix[j]``.  The original used ``good_suffix[j + 1]`` (the shift
    for a shorter matched suffix), which can overshoot real occurrences; for
    example it missed "abab" at index 2 of "aaabab".
    """
    def build_bad_char_table(pattern):
        # Last index of every character in the pattern.
        table = {}
        for i, ch in enumerate(pattern):
            table[ch] = i
        return table

    def build_good_suffix_table(pattern):
        m = len(pattern)
        # suffix[i]: length of the longest substring ending at i that is
        # also a suffix of the whole pattern.
        suffix = [0] * m
        suffix[m - 1] = m
        for i in range(m - 2, -1, -1):
            j = i
            while j >= 0 and pattern[j] == pattern[m - 1 - i + j]:
                j -= 1
            suffix[i] = i - j

        table = [m] * m

        # Case 2: only a prefix of the pattern matches the end of the good suffix.
        j = 0
        for i in range(m - 1, -1, -1):
            if suffix[i] == i + 1:
                while j < m - 1 - i:
                    if table[j] == m:
                        table[j] = m - 1 - i
                    j += 1

        # Case 1: the good suffix reoccurs elsewhere inside the pattern.
        for i in range(m - 1):
            table[m - 1 - suffix[i]] = m - 1 - i

        return table

    n, m = len(text), len(pattern)
    if m == 0:
        return []

    bad_char = build_bad_char_table(pattern)
    good_suffix = build_good_suffix_table(pattern)

    result = []
    i = 0

    while i <= n - m:
        j = m - 1
        while j >= 0 and pattern[j] == text[i + j]:
            j -= 1

        if j < 0:
            result.append(i)
            i += good_suffix[0]
        else:
            bc_shift = j - bad_char.get(text[i + j], -1)
            # good_suffix[j] >= 1, so the window always advances.
            i += max(bc_shift, good_suffix[j])

    return result

def sunday_search(text, pattern):
    """Return all start indices of ``pattern`` in ``text`` (Sunday algorithm).

    After each window comparison the window jumps according to the character
    just past it: m - (last index of that char in pattern), or m + 1 if the
    char does not occur.  Returns [] for an empty pattern.
    """
    def build_shift_table(pattern):
        m = len(pattern)
        return {ch: m - idx for idx, ch in enumerate(pattern)}

    n, m = len(text), len(pattern)
    if m == 0:
        return []

    shift = build_shift_table(pattern)
    matches = []
    pos = 0

    while pos <= n - m:
        if all(text[pos + k] == pattern[k] for k in range(m)):
            matches.append(pos)

        beyond = pos + m
        if beyond >= n:
            break
        pos += shift.get(text[beyond], m + 1)

    return matches

# Demo: the pattern occurs once in the text, at index 10; both searches should report [10].
text = "ABABDABACDABABCABAB"
pattern = "ABABCABAB"
print(f"Boyer-Moore 匹配: {boyer_moore(text, pattern)}")
print(f"Sunday 匹配: {sunday_search(text, pattern)}")

正则表达式应用

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import re

def regex_examples():
    """Print demos of regex findall / sub / split on fixed sample strings."""
    text = "联系电话: 13812345678, 邮箱: test@example.com, 日期: 2024-01-15"

    # findall returns every non-overlapping match as a list of strings.
    phone_re = re.compile(r'1[3-9]\d{9}')
    email_re = re.compile(r'[\w.-]+@[\w.-]+\.\w+')
    date_re = re.compile(r'\d{4}-\d{2}-\d{2}')
    phones, emails, dates = (rx.findall(text) for rx in (phone_re, email_re, date_re))

    print(f"手机号: {phones}")
    print(f"邮箱: {emails}")
    print(f"日期: {dates}")

    # With one capturing group, findall returns just the group's text.
    text2 = "价格: 100元, 200.5元, 300.00元"
    prices = re.compile(r'(\d+\.?\d*)元').findall(text2)
    print(f"价格: {prices}")

    # Collapse whitespace runs, then trim the ends.
    text3 = " hello world python "
    cleaned = re.compile(r'\s+').sub(' ', text3).strip()
    print(f"清理后: '{cleaned}'")

    text4 = "apple,banana;orange|grape"
    fruits = re.compile(r'[,;|]').split(text4)
    print(f"分割: {fruits}")


def validate_inputs():
    """Print the result of three anchored-regex validators on sample values."""

    def _matches(pattern, s):
        return re.match(pattern, s) is not None

    def is_valid_id_card(s):
        # Format check only: the final digit/X is not checksum-verified.
        return _matches(r'^[1-9]\d{5}(18|19|20)\d{2}(0[1-9]|1[0-2])(0[1-9]|[12]\d|3[01])\d{3}[\dXx]$', s)

    def is_valid_ip(s):
        # Four dot-separated octets 0-255; leading zeros such as "01" are accepted.
        return _matches(r'^((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)$', s)

    def is_valid_url(s):
        return _matches(r'^https?://[\w.-]+(:\d+)?(/[\w./-]*)?$', s)

    print(f"身份证验证: {is_valid_id_card('11010519900307233X')}")
    print(f"IP验证: {is_valid_ip('192.168.1.1')}")
    print(f"URL验证: {is_valid_url('https://www.example.com:8080/path')}")


regex_examples()
validate_inputs()

文本处理技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
def text_processing():
    """Section placeholder for the text-processing helpers below; does nothing."""
    pass


def word_frequency(text):
    """Return a Counter of lower-cased word tokens (runs of word characters) in *text*."""
    import re
    from collections import Counter

    words = re.findall(r'\b\w+\b', text.lower())
    return Counter(words)


def find_longest_common_prefix(strs):
    """Return the longest string every element of *strs* starts with ("" for empty input)."""
    if not strs:
        return ""

    prefix = strs[0]
    for s in strs[1:]:
        # Shrink the candidate until this string starts with it.
        while not s.startswith(prefix):
            prefix = prefix[:-1]
            if not prefix:
                return ""

    return prefix


def group_anagrams(strs):
    """Group words that are anagrams of one another, groups in first-seen order."""
    from collections import defaultdict

    groups = defaultdict(list)
    for s in strs:
        # The sorted letters form a key shared by all anagrams.
        groups[''.join(sorted(s))].append(s)

    return list(groups.values())


def text_justification(words, max_width):
    """Fully justify *words* into lines of *max_width* characters.

    Words are packed greedily. On a full line the extra spaces are spread
    evenly, with the leftmost gaps getting one more. A line holding a single
    word, and the last line, are left-justified and right-padded. A word
    longer than *max_width* is emitted on its own line without truncation.
    """
    result = []
    line = []
    line_length = 0  # letters on the current line, spaces excluded

    for word in words:
        # len(line) = single spaces needed if *word* is appended. Only flush a
        # non-empty line: an oversized word arriving on an empty line used to
        # emit a spurious "" line.
        if line and line_length + len(line) + len(word) > max_width:
            spaces = max_width - line_length
            gaps = len(line) - 1
            if gaps == 0:
                result.append(line[0] + ' ' * spaces)
            else:
                base, extra = divmod(spaces, gaps)
                parts = []
                for i, w in enumerate(line[:-1]):
                    parts.append(w)
                    parts.append(' ' * (base + (1 if i < extra else 0)))
                parts.append(line[-1])
                result.append(''.join(parts))
            line = []
            line_length = 0

        line.append(word)
        line_length += len(word)

    if line:
        last = ' '.join(line)
        result.append(last + ' ' * (max_width - len(last)))

    return result


strs = ["flower", "flow", "flight"]
print(f"最长公共前缀: '{find_longest_common_prefix(strs)}'")
print(f"字母异位词分组: {group_anagrams(['eat', 'tea', 'tan', 'ate', 'nat', 'bat'])}")

六十九、数学问题专题

组合数学

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class Combinatorics:
def __init__(self, n, mod=10**9 + 7):
self.n = n
self.mod = mod

self.factorial = [1] * (n + 1)
for i in range(1, n + 1):
self.factorial[i] = self.factorial[i - 1] * i % mod

self.inv_factorial = [1] * (n + 1)
self.inv_factorial[n] = pow(self.factorial[n], mod - 2, mod)
for i in range(n - 1, -1, -1):
self.inv_factorial[i] = self.inv_factorial[i + 1] * (i + 1) % mod

def C(self, n, k):
if k < 0 or k > n:
return 0
return self.factorial[n] * self.inv_factorial[k] % self.mod * self.inv_factorial[n - k] % self.mod

def A(self, n, k):
if k < 0 or k > n:
return 0
return self.factorial[n] * self.inv_factorial[n - k] % self.mod

def catalan(self, n):
return self.C(2 * n, n) * pow(n + 1, self.mod - 2, self.mod) % self.mod

def stirling_second(self, n, k):
result = 0
for i in range(k + 1):
sign = 1 if (k - i) % 2 == 0 else -1
result = (result + sign * self.C(k, i) * pow(i, n, self.mod)) % self.mod
return result * self.inv_factorial[k] % self.mod

def bell_number(self, n):
result = 0
for k in range(n + 1):
result = (result + self.stirling_second(n, k)) % self.mod
return result

comb = Combinatorics(1000)
print(f"C(10, 3) = {comb.C(10, 3)}")
print(f"A(10, 3) = {comb.A(10, 3)}")
print(f"卡特兰数 C(5) = {comb.catalan(5)}")

线性代数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class Matrix:
    """Small dense matrix stored as a list of row lists."""

    def __init__(self, data):
        self.data = data
        self.rows = len(data)
        self.cols = len(data[0]) if data else 0

    def __add__(self, other):
        """Element-wise sum; raises ValueError("矩阵维度不匹配") on shape mismatch."""
        if self.rows != other.rows or self.cols != other.cols:
            raise ValueError("矩阵维度不匹配")

        result = [[self.data[i][j] + other.data[i][j]
                   for j in range(self.cols)] for i in range(self.rows)]
        return Matrix(result)

    def __mul__(self, other):
        """Matrix product (naive triple loop); raises ValueError if inner dimensions differ."""
        if self.cols != other.rows:
            raise ValueError("矩阵维度不匹配")

        result = [[0] * other.cols for _ in range(self.rows)]
        for i in range(self.rows):
            for j in range(other.cols):
                for k in range(self.cols):
                    result[i][j] += self.data[i][k] * other.data[k][j]

        return Matrix(result)

    def scalar_mul(self, c):
        """Return a new matrix with every entry multiplied by c."""
        result = [[self.data[i][j] * c
                   for j in range(self.cols)] for i in range(self.rows)]
        return Matrix(result)

    def transpose(self):
        """Return the transpose as a new matrix."""
        result = [[self.data[j][i]
                   for j in range(self.rows)] for i in range(self.cols)]
        return Matrix(result)

    def determinant(self):
        """Determinant by cofactor expansion along the first row.

        Exact for integer entries but O(n!); intended for small matrices.
        Raises ValueError("非方阵") for non-square input.
        """
        if self.rows != self.cols:
            raise ValueError("非方阵")

        n = self.rows
        if n == 1:
            return self.data[0][0]

        det = 0
        for j in range(n):
            minor = [[self.data[i][k] for k in range(n) if k != j]
                     for i in range(1, n)]
            det += ((-1) ** j) * self.data[0][j] * Matrix(minor).determinant()

        return det

    def inverse(self):
        """Return the inverse via Gauss-Jordan elimination with partial pivoting (floats).

        Singularity is detected from the pivots themselves: a largest remaining
        pivot below 1e-10 in magnitude raises ValueError("矩阵不可逆"). The previous
        version first computed the O(n!) cofactor determinant just for this check.
        Raises ValueError("非方阵") for non-square input.
        """
        if self.rows != self.cols:
            raise ValueError("非方阵")

        n = self.rows
        # [A | I]; row-reducing the left half to I turns the right half into A^-1.
        aug = [list(row) + [1 if i == j else 0 for j in range(n)]
               for i, row in enumerate(self.data)]

        for col in range(n):
            # Partial pivoting: take the largest |entry| in this column for stability.
            pivot_row = max(range(col, n), key=lambda r: abs(aug[r][col]))
            if abs(aug[pivot_row][col]) < 1e-10:
                raise ValueError("矩阵不可逆")
            aug[col], aug[pivot_row] = aug[pivot_row], aug[col]

            pivot = aug[col][col]
            aug[col] = [x / pivot for x in aug[col]]

            for r in range(n):
                if r != col:
                    factor = aug[r][col]
                    if factor:
                        aug[r] = [a - factor * b for a, b in zip(aug[r], aug[col])]

        return Matrix([row[n:] for row in aug])


m1 = Matrix([[1, 2], [3, 4]])
m2 = Matrix([[5, 6], [7, 8]])
print(f"矩阵加法: {(m1 + m2).data}")
print(f"矩阵乘法: {(m1 * m2).data}")
print(f"行列式: {m1.determinant()}")

数论函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def divisor_function(n):
    """Return (number of divisors, sum of divisors) of n by trial division up to sqrt(n)."""
    count = 0
    total = 0
    i = 1
    while i * i <= n:
        if n % i == 0:
            # Divisors come in pairs (i, n // i); a perfect square's root counts once.
            count += 1 if i * i == n else 2
            total += i
            if i != n // i:
                total += n // i
        i += 1
    return count, total


def sum_of_divisors_sieve(n):
    """Return sigma where sigma[k] = sum of divisors of k for 0 <= k <= n (sigma[0] = 0).

    Every entry starts at 1 (the divisor 1); each i >= 2 is then added to all
    of its multiples, O(n log n) in total.
    """
    sigma = [1] * (n + 1)
    sigma[0] = 0

    for i in range(2, n + 1):
        for j in range(i, n + 1, i):
            sigma[j] += i

    return sigma


def number_of_divisors_sieve(n):
    """Return d where d[k] = number of divisors of k for 0 <= k <= n (d[0] = 0).

    Linear (Euler) sieve: each composite i*p is produced once, from its
    smallest prime p. min_exp[k] is the exponent of k's smallest prime, so
      d(i*p) = d(i) / (e + 1) * (e + 2)  when p | i (e = min_exp[i]),
      d(i*p) = 2 * d(i)                  otherwise.
    The previous version kept an unused min_prime table and recounted e with a
    division loop; tracking e directly keeps the sieve O(n).
    """
    d = [1] * (n + 1)
    d[0] = 0

    is_composite = [False] * (n + 1)
    min_exp = [0] * (n + 1)
    primes = []

    for i in range(2, n + 1):
        if not is_composite[i]:
            primes.append(i)
            d[i] = 2
            min_exp[i] = 1

        for p in primes:
            ip = i * p
            if ip > n:
                break
            is_composite[ip] = True

            if i % p == 0:
                # p is i's smallest prime, so its exponent grows by one.
                min_exp[ip] = min_exp[i] + 1
                d[ip] = d[i] // (min_exp[i] + 1) * (min_exp[ip] + 1)
                break
            # p is smaller than every prime of i: a new prime with exponent 1.
            min_exp[ip] = 1
            d[ip] = d[i] * 2

    return d


def mobius_function_sieve(n):
    """Return mu where mu[k] is the Mobius function of k for 1 <= k <= n (linear sieve).

    mu[0] is left at 1 and is not meaningful; callers slice from index 1.
    """
    mu = [1] * (n + 1)
    is_prime = [True] * (n + 1)
    primes = []

    for i in range(2, n + 1):
        if is_prime[i]:
            primes.append(i)
            mu[i] = -1

        for p in primes:
            if i * p > n:
                break
            is_prime[i * p] = False

            if i % p == 0:
                # p^2 divides i*p, so it is not square-free.
                mu[i * p] = 0
                break
            else:
                mu[i * p] = -mu[i]

    return mu


print(f"约数个数和约数和 (12): {divisor_function(12)}")
print(f"莫比乌斯函数 (1-10): {mobius_function_sieve(10)[1:]}")

七十、蓝桥杯历年真题精选

2022 年真题:裁纸刀

题目:小蓝有一个裁纸刀,每次可以将一张纸切成两张。现在有 n 张纸,小蓝想知道最少需要切多少次才能得到 m 张纸。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
def paper_cut(n, m, *, stack=False):
    """Return the minimum number of cuts needed to turn n sheets into at least m sheets.

    The problem statement says each cut splits ONE sheet into two, so every cut
    adds exactly one sheet and the answer is m - n. The previous version doubled
    the sheet count per cut, which only models cutting the whole pile at once;
    that variant is still available with ``stack=True``.

    Raises ValueError when more sheets are needed but n <= 0: nothing can be
    cut, and the doubling loop used to spin forever on n == 0.
    """
    if m <= n:
        return 0
    if n <= 0:
        raise ValueError("n must be positive to produce more sheets")

    if not stack:
        return m - n

    # Pile variant: each cut goes through every sheet, doubling the count.
    cuts = 0
    current = n
    while current < m:
        current *= 2
        cuts += 1
    return cuts


print(f"最少切纸次数: {paper_cut(1, 10)}")

2022 年真题:寻找整数

题目:有一个不超过 10^17 的正整数 n,给出 n 除以 2 到 49 的余数,求满足条件的最小正整数 n。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def find_integer(remainders):
    """Return the smallest positive x with x % m == a for every (m, a) in *remainders*.

    Congruences are merged one at a time with the generalized Chinese Remainder
    Theorem (moduli need not be coprime). Returns None if the system is
    inconsistent; an empty system yields 1.
    """

    def ext_gcd(a, b):
        # Extended Euclid: returns (g, x, y) with a*x + b*y == g.
        if not b:
            return a, 1, 0
        g, x1, y1 = ext_gcd(b, a % b)
        return g, y1, x1 - (a // b) * y1

    x, modulus = 0, 1
    for m, a in remainders.items():
        g, p, _ = ext_gcd(modulus, m)
        diff = a - x
        if diff % g:
            return None  # the two congruences contradict each other
        step = m // g
        merged_mod = modulus // g * m  # lcm(modulus, m)
        x = (x + modulus * (diff // g * p % step)) % merged_mod
        modulus = merged_mod

    # x lies in [0, modulus); 0 is not positive, so shift to modulus.
    return x if x > 0 else x + modulus


remainders = {2: 1, 3: 2, 5: 4, 7: 0, 11: 9}
print(f"满足条件的最小整数: {find_integer(remainders)}")

2022 年真题:最大子矩阵

题目:给定一个 n×m 的矩阵,求最大子矩阵的和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def max_submatrix(matrix):
    """Maximum sum over all non-empty axis-aligned sub-rectangles, in O(n^2 * m).

    Fix a top row, extend the bottom row downward while accumulating per-column
    sums, and run Kadane's maximum-subarray scan over those column sums.
    """
    rows = len(matrix)
    cols = len(matrix[0])
    best = float('-inf')

    for top in range(rows):
        col_sums = [0] * cols
        for bottom in range(top, rows):
            current_row = matrix[bottom]
            for c in range(cols):
                col_sums[c] += current_row[c]

            # Kadane: restart the run whenever the running sum has gone negative.
            ending_here = 0
            for total in col_sums:
                ending_here = total if ending_here < 0 else ending_here + total
                if ending_here > best:
                    best = ending_here

    return best


matrix = [
    [1, 2, -1, -4, -20],
    [-8, -3, 4, 2, 1],
    [3, 8, 10, 1, 3],
    [-4, -1, 1, 7, -6]
]
print(f"最大子矩阵和: {max_submatrix(matrix)}")

2021 年真题:路径计数

题目:小蓝要从左上角走到右下角,每次只能向右或向下走,求有多少种不同的路径。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def unique_paths(m, n):
    """Count right/down-only paths from the top-left to the bottom-right of an m x n grid."""
    grid = [[1] * n for _ in range(m)]
    for r in range(1, m):
        above, here = grid[r - 1], grid[r]
        for c in range(1, n):
            here[c] = above[c] + here[c - 1]
    return grid[-1][-1]


def unique_paths_with_obstacles(obstacle_grid):
    """Count right/down-only paths avoiding cells marked 1 in *obstacle_grid*."""
    rows, cols = len(obstacle_grid), len(obstacle_grid[0])

    if obstacle_grid[0][0] == 1:
        return 0

    ways = [[0] * cols for _ in range(rows)]
    ways[0][0] = 1

    for r in range(rows):
        for c in range(cols):
            if obstacle_grid[r][c] == 1:
                ways[r][c] = 0
                continue
            from_top = ways[r - 1][c] if r else 0
            from_left = ways[r][c - 1] if c else 0
            ways[r][c] += from_top + from_left

    return ways[-1][-1]


print(f"不同路径数 (3x7): {unique_paths(3, 7)}")

2020 年真题:数字三角形

题目:给定一个数字三角形,从顶部出发,每次可以移动到下方相邻的数字,求从顶部到底部的最大路径和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def maximum_path_sum(triangle):
    """Maximum top-to-bottom path sum; each step moves to one of the two numbers directly below.

    Bottom-up DP: *best* holds, for the row below, the best sum from each cell to the base.
    """
    best = list(triangle[-1])
    for i in range(len(triangle) - 2, -1, -1):
        row = triangle[i]
        best = [row[j] + max(best[j], best[j + 1]) for j in range(i + 1)]
    return best[0]


triangle = [
    [7],
    [3, 8],
    [8, 1, 0],
    [2, 7, 4, 4],
    [4, 5, 2, 6, 5]
]
print(f"最大路径和: {maximum_path_sum(triangle)}")

七十一、模拟题与实战演练

模拟题 1:表达式求值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def evaluate_expression_advanced(s):
    """Evaluate an arithmetic expression with + - * /, parentheses and unary +/-.

    Two-stack (values / operators) shunting-yard evaluation; binary operators
    are left-associative. Numbers are parsed as floats, so the result is always
    a float and '/' is true division. Characters other than digits, '.',
    operators and parentheses are ignored.

    Unary minus/plus ("-3+4", "2*-3", "-(1+2)") binds tighter than * and /.
    The previous version treated every '-' as binary, so a leading or
    post-operator minus popped an empty value stack (IndexError).
    """
    import re

    tokens = re.findall(r'\d+\.?\d*|[+\-*/()]', s)
    binary_ops = {'+', '-', '*', '/'}
    # Unary operators get their own symbols so they never consume a left operand.
    precedence = {'+': 1, '-': 1, '*': 2, '/': 2, 'neg': 3, 'pos': 3}

    values = []
    ops = []

    def reduce_top():
        """Pop one operator and apply it to the top of the value stack."""
        op = ops.pop()
        if op == 'neg':
            values.append(-values.pop())
            return
        if op == 'pos':
            return
        b = values.pop()
        a = values.pop()
        if op == '+':
            values.append(a + b)
        elif op == '-':
            values.append(a - b)
        elif op == '*':
            values.append(a * b)
        else:
            values.append(a / b)

    prev = None
    for token in tokens:
        if token == '(':
            ops.append(token)
        elif token == ')':
            while ops and ops[-1] != '(':
                reduce_top()
            ops.pop()  # discard the matching '('
        elif token in binary_ops:
            if token in ('+', '-') and (prev is None or prev in binary_ops or prev == '('):
                # Unary: right-associative and highest precedence, so just push it.
                ops.append('neg' if token == '-' else 'pos')
            else:
                while ops and ops[-1] != '(' and precedence[ops[-1]] >= precedence[token]:
                    reduce_top()
                ops.append(token)
        else:
            values.append(float(token))
        prev = token

    while ops:
        reduce_top()

    return values[0]


print(f"表达式结果: {evaluate_expression_advanced('3+4*2-6/3')}")

模拟题 2:括号匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def is_valid_parentheses(s):
    """True iff every (), [] and {} in s closes in LIFO order; other characters are ignored."""
    closer_to_opener = {')': '(', ']': '[', '}': '{'}
    openers = set(closer_to_opener.values())
    pending = []

    for ch in s:
        if ch in openers:
            pending.append(ch)
        elif ch in closer_to_opener:
            if not pending or pending.pop() != closer_to_opener[ch]:
                return False

    return not pending


def generate_parentheses(n):
    """All well-formed strings of n bracket pairs, in lexicographic order ('(' < ')')."""
    combos = []
    frontier = [('', 0, 0)]  # (prefix, opened, closed)

    while frontier:
        prefix, opened, closed = frontier.pop()
        if len(prefix) == 2 * n:
            combos.append(prefix)
            continue
        # Push ')' first so the '(' branch is explored first (depth-first order).
        if closed < opened:
            frontier.append((prefix + ')', opened, closed + 1))
        if opened < n:
            frontier.append((prefix + '(', opened + 1, closed))

    return combos


def longest_valid_parentheses(s):
    """Length of the longest well-formed '()' substring of s.

    The stack holds indices of unmatched characters; its bottom is the index just
    before the current valid run (initially -1).
    """
    unmatched = [-1]
    longest = 0

    for idx, ch in enumerate(s):
        if ch == '(':
            unmatched.append(idx)
            continue
        unmatched.pop()
        if unmatched:
            longest = max(longest, idx - unmatched[-1])
        else:
            unmatched.append(idx)  # unmatched ')' becomes the new boundary

    return longest


print(f"括号匹配: {is_valid_parentheses('()[]{}')}")
print(f"生成括号 (n=3): {generate_parentheses(3)}")
print(f"最长有效括号: {longest_valid_parentheses(')()())')}")

模拟题 3:区间调度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def merge_intervals(intervals):
    """Merge overlapping or touching [start, end] intervals.

    Returns a NEW list of new lists sorted by start. The caller's list and its
    inner intervals are left untouched; the previous version sorted the argument
    in place and overwrote the end of its first inner list. Tuples are accepted.
    """
    if not intervals:
        return []

    ordered = sorted(intervals, key=lambda iv: iv[0])
    merged = [list(ordered[0])]

    for interval in ordered[1:]:
        if interval[0] <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], interval[1])
        else:
            merged.append(list(interval))

    return merged


def insert_interval(intervals, new_interval):
    """Return the merge of *intervals* plus *new_interval*; *intervals* itself is not modified."""
    return merge_intervals(list(intervals) + [new_interval])


def erase_overlap_intervals(intervals):
    """Minimum number of intervals to remove so the rest do not overlap (touching is allowed).

    Greedy by earliest end; works on a sorted copy so the caller's order is preserved.
    """
    if not intervals:
        return 0

    by_end = sorted(intervals, key=lambda iv: iv[1])
    removed = 0
    end = by_end[0][1]

    for interval in by_end[1:]:
        if interval[0] < end:
            removed += 1
        else:
            end = interval[1]

    return removed


intervals = [[1, 3], [2, 6], [8, 10], [15, 18]]
print(f"合并区间: {merge_intervals(intervals)}")
print(f"移除重叠区间数: {erase_overlap_intervals([[1, 2], [2, 3], [3, 4], [1, 3]])}")

模拟题 4:股票交易

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def max_profit_single(prices):
    """Best profit from at most one buy followed by one later sell (0 if none is profitable)."""
    lowest = float('inf')
    best = 0
    for p in prices:
        if p < lowest:
            lowest = p
        gain = p - lowest
        if gain > best:
            best = gain
    return best


def max_profit_multiple(prices):
    """Best profit with unlimited non-overlapping trades: sum of every positive day-to-day rise."""
    return sum(max(0, today - yesterday) for yesterday, today in zip(prices, prices[1:]))


def max_profit_k_transactions(k, prices):
    """Best profit with at most k buy/sell pairs, O(k * n) DP with rolling rows."""
    n = len(prices)
    if n < 2 or k < 1:
        return 0

    # With k >= n // 2 the limit never binds, so the greedy answer is optimal.
    if k >= n // 2:
        return max_profit_multiple(prices)

    prev = [0] * n  # prev[j]: best profit with one fewer transaction, up to day j
    for _ in range(k):
        cur = [0] * n
        best_hold = -prices[0]  # max over buy days of prev[day] - prices[day]
        for j in range(1, n):
            cur[j] = max(cur[j - 1], prices[j] + best_hold)
            best_hold = max(best_hold, prev[j] - prices[j])
        prev = cur

    return prev[n - 1]


prices = [7, 1, 5, 3, 6, 4]
print(f"单次交易最大利润: {max_profit_single(prices)}")
print(f"多次交易最大利润: {max_profit_multiple(prices)}")
print(f"最多2次交易最大利润: {max_profit_k_transactions(2, prices)}")

模拟题 5:字符串转换

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def edit_distance(word1, word2):
    """Levenshtein distance (insert / delete / replace, each cost 1) between two sequences."""
    m, n = len(word1), len(word2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    # Converting to/from the empty prefix costs its length.
    for i in range(m + 1):
        dp[i][0] = i
    for j in range(n + 1):
        dp[0][j] = j

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if word1[i - 1] == word2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                # delete, insert, replace
                dp[i][j] = 1 + min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1])

    return dp[m][n]


def longest_common_subsequence(text1, text2):
    """Length of the longest common subsequence of two sequences."""
    m, n = len(text1), len(text2)
    dp = [[0] * (n + 1) for _ in range(m + 1)]

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if text1[i - 1] == text2[j - 1]:
                dp[i][j] = dp[i - 1][j - 1] + 1
            else:
                dp[i][j] = max(dp[i - 1][j], dp[i][j - 1])

    return dp[m][n]


def longest_palindromic_subsequence(s):
    """Length of the longest palindromic subsequence of s (0 for an empty string).

    dp[i][j] is the answer for s[i..j]. The previous version read dp[0] from an
    empty table and raised IndexError for s == "".
    """
    n = len(s)
    if n == 0:
        return 0

    dp = [[0] * n for _ in range(n)]

    for i in range(n):
        dp[i][i] = 1

    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            if s[i] == s[j]:
                # For length 2, dp[i + 1][j - 1] is a below-diagonal 0, as required.
                dp[i][j] = dp[i + 1][j - 1] + 2
            else:
                dp[i][j] = max(dp[i + 1][j], dp[i][j - 1])

    return dp[0][n - 1]


print(f"编辑距离: {edit_distance('horse', 'ros')}")
print(f"最长公共子序列: {longest_common_subsequence('abcde', 'ace')}")
print(f"最长回文子序列: {longest_palindromic_subsequence('bbbab')}")

七十二、代码模板与速查表

输入输出模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
from sys import stdin, stdout

def fast_input():
    """Read all of stdin at once and return an iterator over its whitespace-separated tokens."""
    return iter(stdin.read().split())


def solve():
    """Single test case: read n, then n integers, and print their sum."""
    tokens = fast_input()
    count = int(next(tokens))
    values = [int(next(tokens)) for _ in range(count)]
    print(sum(values))


def multi_test_cases():
    """Read t, then t cases of (n, n integers), printing each case's maximum on its own line."""
    tokens = fast_input()
    cases = int(next(tokens))
    for _ in range(cases):
        size = int(next(tokens))
        print(max([int(next(tokens)) for _ in range(size)]))


def read_matrix():
    """Read 'rows cols' followed by rows*cols integers and return them as a list of rows."""
    tokens = fast_input()
    n_rows = int(next(tokens))
    n_cols = int(next(tokens))
    return [[int(next(tokens)) for _ in range(n_cols)] for _ in range(n_rows)]


def output_format():
    """Show three common output layouts for the same results."""
    results = [1, 2, 3, 4, 5]
    as_text = [str(r) for r in results]

    print('\n'.join(as_text))   # one value per line
    print(' '.join(as_text))    # space-separated on one line

    for case_no, value in enumerate(results, start=1):
        print(f"Case {case_no}: {value}")

常用算法模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Cheat sheet of common algorithm templates, kept as text.
# The indentation inside each snippet is significant Python and must be kept;
# the previous text had lost it, so no snippet was valid code.
ALGORITHM_TEMPLATES = """
二分查找:
left, right = 0, n - 1
while left <= right:
    mid = (left + right) // 2
    if check(mid):
        right = mid - 1
    else:
        left = mid + 1
# 循环结束后 left 即第一个满足 check 的下标 (要求 check 单调: 前段 False, 后段 True)

BFS:
queue = deque([start])
visited = set([start])
while queue:
    node = queue.popleft()
    for neighbor in get_neighbors(node):
        if neighbor not in visited:
            visited.add(neighbor)
            queue.append(neighbor)

DFS:
# 注意: Python 默认递归深度约 1000, 图较深时需 sys.setrecursionlimit 或改写为迭代
def dfs(node, visited):
    visited.add(node)
    for neighbor in graph[node]:
        if neighbor not in visited:
            dfs(neighbor, visited)

并查集:
def find(x):
    if parent[x] != x:
        parent[x] = find(parent[x])
    return parent[x]

def union(x, y):
    px, py = find(x), find(y)
    if px != py:
        parent[px] = py

快速幂:
def pow_mod(base, exp, mod):
    result = 1 % mod
    while exp:
        if exp & 1:
            result = result * base % mod
        base = base * base % mod
        exp >>= 1
    return result

GCD:
def gcd(a, b):
    while b:
        a, b = b, a % b
    return a
"""

数据结构模板

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
class UnionFind:
    """Disjoint-set forest over 0..n-1 with path compression, union by rank and set sizes."""

    def __init__(self, n):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.size = [1] * n  # meaningful only at roots

    def find(self, x):
        """Return the root of x's set, pointing every node on the path directly at it."""
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        while x != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x, y):
        """Merge the sets of x and y; return False if they were already joined."""
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # Attach the shallower tree under the deeper one.
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        return True

    def get_size(self, x):
        """Number of elements in x's set."""
        return self.size[self.find(x)]

class SegmentTree:
    """Bottom-up segment tree for point assignment and inclusive range-sum queries.

    Leaves sit at indices [size, size + n); node i's children are 2i and 2i + 1.
    """

    def __init__(self, data):
        self.n = len(data)
        size = 1
        while size < self.n:
            size <<= 1
        self.size = size
        # Internal nodes, then the leaves, then zero padding up to 2 * size.
        self.tree = [0] * size + list(data) + [0] * (size - self.n)
        for node in range(size - 1, 0, -1):
            self.tree[node] = self.tree[node << 1] + self.tree[(node << 1) | 1]

    def update(self, idx, val):
        """Set element idx to val and refresh the sums on its path to the root."""
        pos = idx + self.size
        self.tree[pos] = val
        pos >>= 1
        while pos:
            self.tree[pos] = self.tree[2 * pos] + self.tree[2 * pos + 1]
            pos >>= 1

    def query(self, l, r):
        """Sum of elements l..r inclusive (0 when l > r)."""
        lo, hi = l + self.size, r + self.size
        total = 0
        while lo <= hi:
            if lo & 1:          # lo is a right child: take it, step inward
                total += self.tree[lo]
                lo += 1
            if not hi & 1:      # hi is a left child: take it, step inward
                total += self.tree[hi]
                hi -= 1
            lo >>= 1
            hi >>= 1
        return total

class Trie:
    """Prefix tree; every node is itself a Trie whose children are keyed by character."""

    def __init__(self):
        self.children = {}
        self.is_end = False  # True if a word ends exactly at this node

    def _find_node(self, key):
        """Return the node reached by following *key*, or None if the path breaks."""
        node = self
        for ch in key:
            node = node.children.get(ch)
            if node is None:
                return None
        return node

    def insert(self, word):
        """Add *word*, creating any missing nodes along its path."""
        node = self
        for ch in word:
            nxt = node.children.get(ch)
            if nxt is None:
                nxt = node.children[ch] = Trie()
            node = nxt
        node.is_end = True

    def search(self, word):
        """True iff *word* was inserted as a whole word."""
        node = self._find_node(word)
        return node is not None and node.is_end

    def starts_with(self, prefix):
        """True iff some inserted word begins with *prefix*."""
        return self._find_node(prefix) is not None

七十三、常见问题与解决方案

时间复杂度优化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Plain-text cheat sheet of common optimisation techniques and their usual
# complexities. It is a string constant only; nothing in it is executed.
OPTIMIZATION_TIPS = """
常见优化技巧:

1. 预处理:
- 前缀和: O(n) 预处理, O(1) 查询
- 素数筛: O(n log log n) 预处理
- 阶乘和逆元: O(n) 预处理

2. 空间换时间:
- 哈希表: O(1) 平均查询
- 记忆化: 避免重复计算

3. 数据结构优化:
- 线段树: O(log n) 区间操作
- 树状数组: O(log n) 单点更新和前缀查询
- 优先队列: O(log n) 插入和删除

4. 算法选择:
- 排序: O(n log n) vs O(n²)
- 查找: O(log n) 二分 vs O(n) 线性
- 图遍历: BFS 最短路 vs DFS 路径

5. 常数优化:
- 使用局部变量
- 减少函数调用
- 使用内置函数
- 避免不必要的类型转换
"""

def optimization_examples():
    """Section placeholder for the before/after optimisation examples below; does nothing."""
    pass


def slow_sum(arr, l, r):
    """Sum of arr[l..r] inclusive by slicing: O(r - l) per query."""
    return sum(arr[l:r + 1])


def fast_sum(prefix, l, r):
    """Sum of arr[l..r] inclusive in O(1), where prefix[i] = sum(arr[:i])."""
    return prefix[r + 1] - prefix[l]


def slow_fib(n):
    """Naive exponential-time Fibonacci, kept as the 'before' example."""
    if n <= 1:
        return n
    return slow_fib(n - 1) + slow_fib(n - 2)


def fast_fib(n):
    """Exact Fibonacci number F(n) in O(n) time and O(1) space.

    The previous memoised-recursion version recursed about n levels deep and hit
    RecursionError near n = 1000 (CPython's default limit); iterating avoids it.
    Returns n unchanged for n <= 1, as before.
    """
    if n <= 1:
        return n
    prev, cur = 0, 1
    for _ in range(n - 1):
        prev, cur = cur, prev + cur
    return cur


def matrix_fib(n, mod=10**9 + 7):
    """F(n) % mod in O(log n) by fast exponentiation of [[1, 1], [1, 0]].

    *mod* used to be hard-coded to 10**9 + 7, which is still the default.
    Returns n unchanged (not reduced) for n <= 1, as before.
    """
    if n <= 1:
        return n

    def matrix_mult(A, B):
        return [[sum(A[i][k] * B[k][j] for k in range(2)) % mod for j in range(2)]
                for i in range(2)]

    def matrix_pow(M, e):
        result = [[1, 0], [0, 1]]
        while e:
            if e & 1:
                result = matrix_mult(result, M)
            M = matrix_mult(M, M)
            e >>= 1
        return result

    # [[1,1],[1,0]]^(n-1) has F(n) in its top-left entry.
    return matrix_pow([[1, 1], [1, 0]], n - 1)[0][0]

调试技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def debug_tools():
    """Section placeholder for the debugging helpers below; does nothing."""
    pass


def debug_print(*args, **kwargs):
    """print() with a "[DEBUG]" prefix, sent to stderr so judged stdout stays clean."""
    import sys
    print("[DEBUG]", *args, **kwargs, file=sys.stderr)


def print_array(arr, name="array", file=None):
    """Show the first 10 items of *arr* and its length.

    Writes to *file*, defaulting to sys.stderr like debug_print. These helpers
    used to print to stdout, where leftover debug output corrupts a judged answer.
    """
    import sys
    out = sys.stderr if file is None else file
    tail = '...' if len(arr) > 10 else ''
    print(f"{name}: {arr[:10]}{tail} (len={len(arr)})", file=out)


def print_matrix(matrix, name="matrix", file=None):
    """Show up to 5 rows (10 columns each) of *matrix*; writes to *file* (default stderr)."""
    import sys
    out = sys.stderr if file is None else file
    print(f"{name}:", file=out)
    for i, row in enumerate(matrix[:5]):
        print(f" [{i}]: {row[:10]}{'...' if len(row) > 10 else ''}", file=out)
    if len(matrix) > 5:
        print(f" ... ({len(matrix)} rows total)", file=out)


def print_graph(adj, name="graph", file=None):
    """Show the neighbour lists of the first 5 nodes of list-based *adj*; writes to *file* (default stderr)."""
    import sys
    out = sys.stderr if file is None else file
    print(f"{name}:", file=out)
    for i, neighbors in enumerate(adj[:5]):
        print(f" {i}: {neighbors}", file=out)
    if len(adj) > 5:
        print(f" ... ({len(adj)} nodes total)", file=out)


def assert_equal(actual, expected, msg=""):
    """Raise AssertionError with an expected/actual report when the values differ.

    Uses an explicit raise rather than ``assert``, so the check survives ``python -O``.
    """
    if actual != expected:
        raise AssertionError(f"{msg}\nExpected: {expected}\nActual: {actual}")


def test_function(func, test_cases):
    """Run func over (input, expected) pairs and print PASSED/FAILED per case to stdout.

    A tuple input is unpacked into positional arguments; anything else is passed as-is.
    """
    for i, (input_data, expected) in enumerate(test_cases):
        actual = func(*input_data) if isinstance(input_data, tuple) else func(input_data)
        if actual != expected:
            print(f"Test case {i+1} FAILED:")
            print(f" Input: {input_data}")
            print(f" Expected: {expected}")
            print(f" Actual: {actual}")
        else:
            print(f"Test case {i+1} PASSED")


# A helper named test_* would otherwise be collected (and fail) under pytest.
test_function.__test__ = False

边界条件处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def boundary_handling():
    """Section placeholder for the boundary-handling helpers below; does nothing."""
    pass


def safe_divide(a, b, default=0):
    """a / b, or *default* when b is zero."""
    return a / b if b != 0 else default


def safe_mod(a, b):
    """a % b, or 0 when b is zero."""
    return a % b if b != 0 else 0


def safe_index(arr, idx, default=None):
    """arr[idx] for a non-negative in-range idx, else *default* (negative indices are rejected)."""
    if 0 <= idx < len(arr):
        return arr[idx]
    return default


def safe_2d_index(matrix, row, col, default=None):
    """matrix[row][col] when both indices are in range, else *default*.

    The column bound is checked against the requested row itself. The previous
    version used len(matrix[0]), which raised IndexError on ragged rows.
    """
    if 0 <= row < len(matrix) and 0 <= col < len(matrix[row]):
        return matrix[row][col]
    return default


def clamp(value, min_val, max_val):
    """Restrict value to the closed range [min_val, max_val]."""
    return max(min_val, min(max_val, value))


def handle_empty_input(arr):
    """Sum of arr, with an explicit early return of 0 for empty input."""
    if not arr:
        return 0
    return sum(arr)


def handle_single_element(arr):
    """Maximum of arr, short-circuiting the single-element case (empty arr raises ValueError)."""
    if len(arr) == 1:
        return arr[0]
    return max(arr)


def check_constraints(value, min_val, max_val):
    """Return value if min_val <= value <= max_val, else raise ValueError."""
    if not (min_val <= value <= max_val):
        raise ValueError(f"Value {value} out of range [{min_val}, {max_val}]")
    return value

七十四、进阶学习路径

知识体系图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
蓝桥杯 Python 高中组知识体系

基础层:
├── Python 语法
│ ├── 变量与数据类型
│ ├── 控制结构
│ ├── 函数与模块
│ └── 面向对象
├── 基础数据结构
│ ├── 列表与元组
│ ├── 字符串
│ ├── 字典与集合
│ └── 栈与队列
└── 基础算法
├── 枚举与模拟
├── 排序与查找
├── 递归与分治
└── 贪心算法

进阶层:
├── 高级数据结构
│ ├── 树与图
│ ├── 堆与优先队列
│ ├── 线段树与树状数组
│ └── 并查集与字典树
├── 动态规划
│ ├── 线性 DP
│ ├── 区间 DP
│ ├── 树形 DP
│ └── 状态压缩 DP
└── 图论算法
├── 最短路
├── 最小生成树
├── 拓扑排序
└── 强连通分量

高级层:
├── 字符串算法
│ ├── KMP 与扩展 KMP
│ ├── AC 自动机
│ ├── 后缀数组
│ └── 字符串哈希
├── 数学算法
│ ├── 数论
│ ├── 组合数学
│ ├── 博弈论
│ └── 计算几何
└── 高级技巧
├── 网络流
├── 随机化算法
├── 分块与莫队
└── 优化技巧

学习建议

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
备考时间规划:

第 1-2 周: 基础语法复习
- 熟练掌握 Python 输入输出
- 理解常用数据类型操作
- 练习基础题目

第 3-4 周: 数据结构
- 掌握列表、字典、集合操作
- 理解栈、队列的应用
- 练习数据结构题目

第 5-6 周: 基础算法
- 排序、二分查找
- 递归、分治
- 贪心算法

第 7-8 周: 动态规划
- 理解 DP 基本思想
- 练习经典 DP 问题
- 背包、区间 DP

第 9-10 周: 图论基础
- 图的表示与遍历
- 最短路、最小生成树
- 拓扑排序

第 11-12 周: 真题训练
- 完成历年真题
- 总结题型与技巧
- 查漏补缺

比赛前一周:
- 复习代码模板
- 保持手感
- 调整心态

推荐资源

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
在线练习平台:
- 蓝桥杯官方练习系统
- 洛谷 (luogu.com.cn)
- 力扣 (leetcode.cn)
- AcWing

推荐书籍:
- 《算法竞赛入门经典》
- 《挑战程序设计竞赛》
- 《算法导论》

学习建议:
1. 每天保持 2-3 小时练习
2. 重视基础,不要好高骛远
3. 多做真题,熟悉题型
4. 注意代码规范和调试技巧
5. 保持耐心,循序渐进

七十五、附录

Python 常用内置函数

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# Quick reference of frequently used built-ins and str/list methods (text only).
# Note: Python's round() uses round-half-to-even ("banker's rounding"), not
# schoolbook half-up rounding; the entry below says so explicitly.
BUILTIN_FUNCTIONS = """
数学函数:
abs(x) 绝对值
max(iterable) 最大值
min(iterable) 最小值
sum(iterable) 求和
pow(x, y) x 的 y 次方
pow(x, y, m) (x 的 y 次方) % m, 内置快速幂
round(x, n) 保留 n 位小数 (银行家舍入, 并非四舍五入: round(2.5) == 2)
divmod(a, b) 返回商和余数

类型转换:
int(x) 转整数
float(x) 转浮点数
str(x) 转字符串
list(x) 转列表
tuple(x) 转元组
set(x) 转集合
dict(x) 转字典

序列操作:
len(s) 长度
sorted(s) 排序
reversed(s) 反转
enumerate(s) 枚举
zip(*s) 打包
map(f, s) 映射
filter(f, s) 过滤
all(s) 全真
any(s) 存真

字符串方法:
s.split() 分割
s.join() 连接
s.strip() 去空格
s.replace() 替换
s.find() 查找
s.count() 计数
s.upper() 大写
s.lower() 小写

列表方法:
lst.append(x) 追加
lst.extend(s) 扩展
lst.insert(i,x) 插入
lst.pop() 弹出
lst.remove(x) 删除
lst.sort() 排序
lst.reverse() 反转
lst.index(x) 索引
lst.count(x) 计数
"""

常用数学公式

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# Formula cheat sheet (text only). Preconditions are spelled out because the
# identities are false without them (e.g. Fermat needs p prime and p not dividing a).
MATH_FORMULAS = """
等差数列:
第 n 项: a_n = a_1 + (n-1)d
前 n 项和: S_n = n(a_1 + a_n)/2 = na_1 + n(n-1)d/2

等比数列:
第 n 项: a_n = a_1 * r^(n-1)
前 n 项和: S_n = a_1(1-r^n)/(1-r) (r ≠ 1; r = 1 时 S_n = n*a_1)

排列组合:
排列: A(n,m) = n!/(n-m)!
组合: C(n,m) = n!/(m!(n-m)!)
卡特兰数: C_n = C(2n,n)/(n+1)

数论:
GCD(a,b) * LCM(a,b) = a * b (a, b 为正整数)
欧拉定理: a^φ(m) ≡ 1 (mod m), 要求 gcd(a, m) = 1
费马小定理: a^(p-1) ≡ 1 (mod p), 要求 p 为质数且 p 不整除 a

几何:
三角形面积: S = (1/2) * |x1(y2-y3) + x2(y3-y1) + x3(y1-y2)|
点 (x0, y0) 到直线 Ax + By + C = 0 的距离: d = |A*x0 + B*y0 + C| / sqrt(A² + B²)
两点距离: d = sqrt((x1-x2)² + (y1-y2)²)

图论:
完全图边数: n(n-1)/2
树的边数: n-1
欧拉路径 (无向连通图): 奇度顶点数为 0 或 2
"""

复杂度对照表

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
常见时间复杂度对照表 (n = 10^6):

O(1) 瞬间完成
O(log n) 约 20 次操作
O(n) 约 10^6 次操作
O(n log n) 约 2*10^7 次操作
O(n²) 约 10^12 次操作 (超时)
O(2^n) n > 20 时超时
O(n!) n > 10 时超时

数据规模与算法选择:
n ≤ 10 O(n!), O(2^n)
n ≤ 20 O(2^n), O(n²)
n ≤ 100 O(n³)
n ≤ 1000 O(n²)
n ≤ 10^4 O(n²) (常数小时)
n ≤ 10^5 O(n log n)
n ≤ 10^6 O(n), O(n log n)
n ≤ 10^7 O(n)
n ≤ 10^8 O(n) (常数小; 纯 Python 循环通常会超时, 需借助内置函数或数学公式)

七十六、高级数据结构实现

平衡树 - Treap

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import random

class TreapNode:
    """Treap node: BST-ordered by key, max-heap-ordered by a random priority.

    count holds duplicates of the same key; size is the number of keys
    (with multiplicity) in this node's subtree.
    """

    def __init__(self, key):
        self.key = key
        self.priority = random.randint(1, 10**9)
        self.left = None
        self.right = None
        self.size = 1
        self.count = 1


class Treap:
    """Multiset ordered by key, balanced in expectation via random priorities."""

    def __init__(self):
        self.root = None

    def _update_size(self, node):
        """Recompute node.size from its children and its own duplicate count."""
        if node:
            left_size = node.left.size if node.left else 0
            right_size = node.right.size if node.right else 0
            node.size = left_size + right_size + node.count

    def _rotate_right(self, y):
        """Lift y.left above y; returns the new subtree root."""
        x = y.left
        y.left = x.right
        x.right = y
        self._update_size(y)
        self._update_size(x)
        return x

    def _rotate_left(self, x):
        """Lift x.right above x; returns the new subtree root."""
        y = x.right
        x.right = y.left
        y.left = x
        self._update_size(x)
        self._update_size(y)
        return y

    def insert(self, key):
        """Add one occurrence of key."""
        self.root = self._insert(self.root, key)

    def _insert(self, node, key):
        if not node:
            return TreapNode(key)

        if key == node.key:
            node.count += 1
        elif key < node.key:
            node.left = self._insert(node.left, key)
            # Restore the heap property by rotating the higher-priority child up.
            if node.left.priority > node.priority:
                node = self._rotate_right(node)
        else:
            node.right = self._insert(node.right, key)
            if node.right.priority > node.priority:
                node = self._rotate_left(node)

        self._update_size(node)
        return node

    def delete(self, key):
        """Remove one occurrence of key (no-op if absent)."""
        self.root = self._delete(self.root, key)

    def _delete(self, node, key):
        if not node:
            return None

        if key < node.key:
            node.left = self._delete(node.left, key)
        elif key > node.key:
            node.right = self._delete(node.right, key)
        else:
            if node.count > 1:
                node.count -= 1
            else:
                if not node.left:
                    return node.right
                if not node.right:
                    return node.left

                # Rotate the higher-priority child up, then delete from the side the node moved to.
                if node.left.priority > node.right.priority:
                    node = self._rotate_right(node)
                    node.right = self._delete(node.right, key)
                else:
                    node = self._rotate_left(node)
                    node.left = self._delete(node.left, key)

        self._update_size(node)
        return node

    def find_kth(self, k):
        """Return the k-th smallest key (1-based, duplicates counted), or None if out of range."""
        return self._find_kth(self.root, k)

    def _find_kth(self, node, k):
        if not node:
            return None

        left_size = node.left.size if node.left else 0

        if k <= left_size:
            return self._find_kth(node.left, k)
        elif k <= left_size + node.count:
            return node.key
        else:
            return self._find_kth(node.right, k - left_size - node.count)

    def get_rank(self, key):
        """Return 1 + the number of stored keys strictly less than key.

        key need not be present. Previously an absent key returned only the
        count of smaller keys, inconsistent with present keys (off by one).
        """
        return self._get_rank(self.root, key)

    def _get_rank(self, node, key):
        smaller = 0
        while node:
            left_size = node.left.size if node.left else 0
            if key < node.key:
                node = node.left
            elif key == node.key:
                smaller += left_size
                break
            else:
                smaller += left_size + node.count
                node = node.right
        return smaller + 1


treap = Treap()
for x in [5, 3, 7, 1, 9, 4, 6]:
    treap.insert(x)

print(f"第3小的数: {treap.find_kth(3)}")
print(f"5的排名: {treap.get_rank(5)}")

可持久化线段树

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class PersistentSegmentTree:
    """Persistent segment tree over value ranks, answering static range k-th smallest.

    Version v (0 <= v <= len(arr)) counts the values of the prefix arr[:v], so
    query_kth(l, r, k) answers on the slice arr[l:r] (0-based, half-open),
    i.e. the 1-based inclusive range [l + 1, r]. Nodes live in self.nodes as
    dicts {'left', 'right', 'count'} referenced by index.
    """

    def __init__(self, arr):
        self.n = len(arr)
        self.roots = [None]
        self.nodes = []

        # Coordinate compression: values -> ranks 0..m-1 and back.
        sorted_arr = sorted(set(arr))
        self.values = {v: i for i, v in enumerate(sorted_arr)}
        self.rev_values = dict(enumerate(sorted_arr))
        self.m = len(sorted_arr)

        if self.m == 0:
            # _build(0, -1) would never reach a leaf and recursed forever;
            # an empty array has nothing to query.
            return

        self.roots[0] = self._build(0, self.m - 1)

        for val in arr:
            new_root = self._update(self.roots[-1], 0, self.m - 1, self.values[val])
            self.roots.append(new_root)

    def _new_node(self):
        node = {'left': None, 'right': None, 'count': 0}
        self.nodes.append(node)
        return len(self.nodes) - 1

    def _build(self, l, r):
        """Build the all-zero tree for rank range [l, r]; returns its root index."""
        node_idx = self._new_node()
        node = self.nodes[node_idx]

        if l == r:
            return node_idx

        mid = (l + r) // 2
        node['left'] = self._build(l, mid)
        node['right'] = self._build(mid + 1, r)

        return node_idx

    def _update(self, prev_node, l, r, pos):
        """Copy the root-to-leaf path for rank pos with counts +1, sharing every other node."""
        node_idx = self._new_node()
        node = self.nodes[node_idx]
        prev = self.nodes[prev_node]

        node['left'] = prev['left']
        node['right'] = prev['right']
        node['count'] = prev['count'] + 1

        if l == r:
            return node_idx

        mid = (l + r) // 2

        if pos <= mid:
            node['left'] = self._update(prev['left'], l, mid, pos)
        else:
            node['right'] = self._update(prev['right'], mid + 1, r, pos)

        return node_idx

    def query_kth(self, l_version, r_version, k):
        """k-th smallest (1-based) value of arr[l_version:r_version].

        Raises ValueError unless 1 <= k <= r_version - l_version; an
        out-of-range k used to silently return an unrelated value.
        """
        if not 1 <= k <= r_version - l_version:
            raise ValueError(f"k={k} out of range for arr[{l_version}:{r_version}]")
        return self._query_kth(
            self.roots[l_version],
            self.roots[r_version],
            0, self.m - 1, k
        )

    def _query_kth(self, left_node, right_node, l, r, k):
        if l == r:
            return self.rev_values[l]

        # Values of the slice that fall in the left half of the rank range.
        left_count = self.nodes[self.nodes[right_node]['left']]['count'] - \
            self.nodes[self.nodes[left_node]['left']]['count']

        mid = (l + r) // 2

        if k <= left_count:
            return self._query_kth(
                self.nodes[left_node]['left'],
                self.nodes[right_node]['left'],
                l, mid, k
            )
        else:
            return self._query_kth(
                self.nodes[left_node]['right'],
                self.nodes[right_node]['right'],
                mid + 1, r, k - left_count
            )


arr = [1, 5, 2, 4, 3]
pst = PersistentSegmentTree(arr)
# 1-based interval [1, 4] is the slice arr[0:4]; the old call query_kth(1, 4, 2) covered arr[1:4].
print(f"区间[1,4]第2小的数: {pst.query_kth(0, 4, 2)}")

树链剖分

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
class HeavyLightDecomposition:
    """Heavy-light decomposition of a tree rooted at vertex 0.

    After construction:

    * ``parent[u]``, ``depth[u]`` and ``size[u]`` describe the rooted tree
      (``parent[0] == -1``).
    * ``heavy[u]`` is the child of ``u`` with the largest subtree (the first
      such child in ``adj[u]`` order on ties), or -1 for a leaf.
    * ``head[u]`` is the top vertex of the heavy chain containing ``u``.
    * ``pos[u]`` is ``u``'s index in a DFS order that always visits the heavy
      child first, so every heavy chain is a contiguous range of ``pos``.

    Both traversals are iterative. The previous recursive versions raised
    RecursionError on deep trees; a path of a few thousand vertices already
    exceeds CPython's default limit of 1000 frames.

    Args:
        n: number of vertices, labelled ``0 .. n-1``.
        adj: undirected adjacency lists of a tree (``adj[u]`` = neighbours).
    """

    def __init__(self, n, adj):
        self.n = n
        self.adj = adj

        self.parent = [-1] * n
        self.depth = [0] * n
        self.size = [0] * n
        self.heavy = [-1] * n

        self.head = [0] * n
        self.pos = [0] * n
        self.current_pos = 0

        # An empty tree has no root; the old code raised IndexError here.
        if n > 0:
            self._dfs(0, -1)
            self._decompose(0, 0)

    def _dfs(self, u, p):
        """Fill parent/depth/size/heavy for the subtree of ``u``, whose parent is ``p``."""
        # Pass 1 (top-down): discover vertices; each one is appended to
        # ``order`` after its parent.
        order = [u]
        stack = [(u, p)]
        while stack:
            x, px = stack.pop()
            for v in self.adj[x]:
                if v != px:
                    self.parent[v] = x
                    self.depth[v] = self.depth[x] + 1
                    order.append(v)
                    stack.append((v, x))

        # Pass 2 (bottom-up: children before parents): subtree sizes and the
        # heavy child, scanning neighbours in adj order like the recursion did.
        for x in reversed(order):
            px = p if x == u else self.parent[x]
            self.size[x] = 1
            max_size = 0
            for v in self.adj[x]:
                if v != px:
                    self.size[x] += self.size[v]
                    if self.size[v] > max_size:
                        max_size = self.size[v]
                        self.heavy[x] = v

    def _decompose(self, u, h):
        """Assign ``head`` and ``pos`` for the subtree of ``u``, whose chain head is ``h``."""
        stack = [(u, h)]
        while stack:
            x, hx = stack.pop()
            self.head[x] = hx
            self.pos[x] = self.current_pos
            self.current_pos += 1

            heavy = self.heavy[x]
            # Push the light children (each starts its own chain) in reverse,
            # then the heavy child on top. Popping then yields the heavy
            # subtree first and the light subtrees in adj order, which is the
            # same preorder the recursive version produced.
            light = [v for v in self.adj[x] if v != self.parent[x] and v != heavy]
            stack.extend((v, v) for v in reversed(light))
            if heavy != -1:
                stack.append((heavy, hx))

    def lca(self, u, v):
        """Return the lowest common ancestor of ``u`` and ``v``."""
        # Lift whichever vertex sits on the chain with the deeper head.
        while self.head[u] != self.head[v]:
            if self.depth[self.head[u]] > self.depth[self.head[v]]:
                u = self.parent[self.head[u]]
            else:
                v = self.parent[self.head[v]]

        # Same chain now: the shallower vertex is the ancestor.
        return u if self.depth[u] < self.depth[v] else v

    def path_query(self, u, v, tree):
        """Add up ``tree.query(l, r)`` over the ``pos`` ranges covering the u-v path.

        ``tree`` must provide ``query(l, r)`` over inclusive ``pos`` ranges
        and return values supporting ``+``; for example, a segment tree over
        values laid out by ``self.pos``. Every path vertex, including the LCA,
        is covered exactly once.
        """
        result = 0

        while self.head[u] != self.head[v]:
            if self.depth[self.head[u]] > self.depth[self.head[v]]:
                result += tree.query(self.pos[self.head[u]], self.pos[u])
                u = self.parent[self.head[u]]
            else:
                result += tree.query(self.pos[self.head[v]], self.pos[v])
                v = self.parent[self.head[v]]

        if self.depth[u] > self.depth[v]:
            u, v = v, u

        result += tree.query(self.pos[u], self.pos[v])

        return result

# Demo: 7-vertex tree rooted at 0; vertices 3 and 5 lie under different
# children of the root, so their LCA is 0.
n = 7
adj = [[1, 2], [0, 3, 4], [0, 5, 6], [1], [1], [2], [2]]
hld = HeavyLightDecomposition(n, adj)
print(f"LCA(3, 5): {hld.lca(3, 5)}")

七十七、网络流算法详解

最大流 - Dinic 算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from collections import deque

class Dinic:
    """Dinic's maximum-flow algorithm.

    ``adj[u]`` holds edges as mutable lists ``[to, residual_cap, rev]``,
    where ``rev`` is the index of the paired reverse edge in ``adj[to]``.
    ``max_flow`` consumes residual capacity, so a second call on the same
    object only finds additional flow (normally 0).
    """

    def __init__(self, n):
        self.n = n
        self.adj = [[] for _ in range(n)]

    def add_edge(self, u, v, cap):
        """Add directed edge u -> v with capacity ``cap`` plus its 0-capacity reverse."""
        self.adj[u].append([v, cap, len(self.adj[v])])
        self.adj[v].append([u, 0, len(self.adj[u]) - 1])

    def _bfs(self, s, t, level):
        """Fill ``level`` with residual-graph BFS distances from ``s``; True if ``t`` is reachable."""
        for i in range(self.n):
            level[i] = -1

        queue = deque([s])
        level[s] = 0

        while queue:
            u = queue.popleft()

            for v, cap, _rev in self.adj[u]:
                if level[v] == -1 and cap > 0:
                    level[v] = level[u] + 1
                    queue.append(v)

        return level[t] != -1

    def _dfs(self, u, t, flow, level, ptr):
        """Push up to ``flow`` units from ``u`` to ``t`` along one level-graph path.

        ``ptr[u]`` is u's current-arc pointer: edges before it are known to
        be useless for the rest of this phase. The previous version left
        ``ptr[u]`` on the last edge after exhausting the list, so later calls
        descended into dead subtrees again. Recursion depth is at most
        ``level[t]``.
        """
        if u == t:
            return flow

        edges = self.adj[u]
        while ptr[u] < len(edges):
            edge = edges[ptr[u]]
            v, cap, rev = edge

            if cap > 0 and level[v] == level[u] + 1:
                pushed = self._dfs(v, t, min(flow, cap), level, ptr)

                if pushed > 0:
                    edge[1] -= pushed
                    self.adj[v][rev][1] += pushed
                    # Keep ptr[u] here: this edge may still have capacity left.
                    return pushed

            # Saturated, or a dead end in this level graph: skip it for good.
            ptr[u] += 1

        return 0

    def max_flow(self, s, t):
        """Return the maximum s-t flow value (0 when ``s == t``)."""
        if s == t:
            # Previously _dfs returned inf immediately and this loop never ended.
            return 0

        flow = 0
        level = [-1] * self.n

        while self._bfs(s, t, level):
            ptr = [0] * self.n

            while True:
                pushed = self._dfs(s, t, float('inf'), level, ptr)
                if pushed == 0:
                    break
                flow += pushed

        return flow

# Demo: 6-vertex network, source 0, sink 5; add_edge(u, v, capacity).
# The maximum flow of this network is 19.
dinic = Dinic(6)
dinic.add_edge(0, 1, 10)
dinic.add_edge(0, 2, 10)
dinic.add_edge(1, 2, 2)
dinic.add_edge(1, 3, 4)
dinic.add_edge(1, 4, 8)
dinic.add_edge(2, 4, 9)
dinic.add_edge(3, 5, 10)
dinic.add_edge(4, 3, 6)
dinic.add_edge(4, 5, 10)

print(f"最大流: {dinic.max_flow(0, 5)}")

最小费用最大流

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class MinCostMaxFlow:
    """Min-cost max-flow via successive shortest paths found with SPFA.

    Edges are stored as ``[to, residual_cap, cost, rev]``; reverse edges
    carry ``-cost``. The residual graph must not contain a negative-cost
    cycle, because ``_spfa`` has no cycle detection and would not terminate.
    """

    def __init__(self, n):
        self.n = n
        self.adj = [[] for _ in range(n)]

    def add_edge(self, u, v, cap, cost):
        """Add directed edge u -> v (capacity ``cap``, unit cost ``cost``) and its reverse."""
        self.adj[u].append([v, cap, cost, len(self.adj[v])])
        self.adj[v].append([u, 0, -cost, len(self.adj[u]) - 1])

    def _spfa(self, s, t, dist, parent, parent_edge):
        """Shortest residual path from ``s`` by queue-based Bellman-Ford (SPFA).

        ``dist`` must be pre-filled with inf by the caller. On return,
        ``parent[v]`` and ``parent_edge[v]`` (an index into ``adj[parent[v]]``)
        describe the shortest-path tree. Returns True if ``t`` is reachable.
        """
        in_queue = [False] * self.n
        dist[s] = 0
        in_queue[s] = True

        queue = deque([s])

        while queue:
            u = queue.popleft()
            in_queue[u] = False

            for i, (v, cap, cost, rev) in enumerate(self.adj[u]):
                if cap > 0 and dist[v] > dist[u] + cost:
                    dist[v] = dist[u] + cost
                    parent[v] = u
                    parent_edge[v] = i

                    if not in_queue[v]:
                        queue.append(v)
                        in_queue[v] = True

        return dist[t] != float('inf')

    def min_cost_max_flow(self, s, t):
        """Return ``(max_flow, min_cost)`` from ``s`` to ``t`` (``(0, 0)`` when ``s == t``)."""
        if s == t:
            # Previously the path walk below was skipped, path_flow stayed inf
            # and the loop never terminated.
            return 0, 0

        flow = 0
        cost = 0

        while True:
            dist = [float('inf')] * self.n
            parent = [-1] * self.n
            parent_edge = [-1] * self.n

            if not self._spfa(s, t, dist, parent, parent_edge):
                break

            # Bottleneck capacity along the shortest path.
            path_flow = float('inf')
            v = t
            while v != s:
                u = parent[v]
                edge_idx = parent_edge[v]
                path_flow = min(path_flow, self.adj[u][edge_idx][1])
                v = u

            flow += path_flow
            cost += path_flow * dist[t]

            # Augment: consume forward capacity, release reverse capacity.
            v = t
            while v != s:
                u = parent[v]
                edge_idx = parent_edge[v]
                rev = self.adj[u][edge_idx][3]

                self.adj[u][edge_idx][1] -= path_flow
                self.adj[v][rev][1] += path_flow

                v = u

        return flow, cost

# Demo: 4-vertex network, source 0, sink 3; add_edge(u, v, capacity, cost).
# Expected result: max flow 3 with minimum cost 11.
mcmf = MinCostMaxFlow(4)
mcmf.add_edge(0, 1, 2, 2)
mcmf.add_edge(0, 2, 1, 1)
mcmf.add_edge(1, 2, 1, 1)
mcmf.add_edge(1, 3, 1, 3)
mcmf.add_edge(2, 3, 2, 1)

flow, cost = mcmf.min_cost_max_flow(0, 3)
print(f"最大流: {flow}, 最小费用: {cost}")

二分图匹配

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def hungarian_algorithm(adj, n, m):
    """Maximum bipartite matching by repeated DFS augmenting paths (Kuhn's method).

    Despite the name this is the unweighted augmenting-path algorithm, not
    the weighted Hungarian (Kuhn-Munkres) assignment method.

    Args:
        adj: n x m matrix; a truthy ``adj[u][v]`` means left u may pair with right v.
        n: number of left vertices.
        m: number of right vertices.

    Returns:
        ``(size, match_r)`` where ``match_r[v]`` is the left vertex matched
        to right vertex v, or -1.
    """
    match_r = [-1] * m

    def try_augment(u, visited):
        # Try each right vertex adjacent to u that this search has not used yet.
        for v in range(m):
            if not adj[u][v] or visited[v]:
                continue
            visited[v] = True
            owner = match_r[v]
            if owner == -1 or try_augment(owner, visited):
                match_r[v] = u
                return True
        return False

    matched = sum(1 for u in range(n) if try_augment(u, [False] * m))
    return matched, match_r

def hopcroft_karp(adj, n, m):
    """Maximum bipartite matching with the Hopcroft-Karp algorithm.

    Args:
        adj: n x m matrix; a truthy ``adj[u][v]`` means left vertex u may be
            matched to right vertex v.
        n: number of left vertices.
        m: number of right vertices.

    Returns:
        ``(size, pair_u, pair_v)``: ``pair_u[u]`` is u's right partner or -1,
        and ``pair_v[v]`` is v's left partner or -1.

    Each phase BFS-layers the graph from all free left vertices and records
    ``limit``, the length of the shortest augmenting path. The DFS then only
    augments along paths of exactly that length. The previous DFS accepted a
    free right vertex at any depth, so it could augment along longer paths
    and lose the phase invariant behind the O(E * sqrt(V)) bound.
    """
    from collections import deque

    INF = float('inf')
    pair_u = [-1] * n
    pair_v = [-1] * m
    dist = [0] * n
    limit = [INF]  # shortest augmenting-path length found by the last BFS

    def bfs():
        queue = deque()

        for u in range(n):
            if pair_u[u] == -1:
                dist[u] = 0
                queue.append(u)
            else:
                dist[u] = INF

        limit[0] = INF

        while queue:
            u = queue.popleft()
            if dist[u] >= limit[0]:
                # Layers at or beyond the shortest augmenting path are not expanded.
                continue
            for v in range(m):
                if not adj[u][v]:
                    continue
                w = pair_v[v]
                if w == -1:
                    if limit[0] == INF:
                        limit[0] = dist[u] + 1
                elif dist[w] == INF:
                    dist[w] = dist[u] + 1
                    queue.append(w)

        return limit[0] != INF

    def dfs(u):
        for v in range(m):
            if not adj[u][v]:
                continue
            w = pair_v[v]
            if w == -1:
                # A free right vertex only ends a path of the shortest length.
                ok = dist[u] + 1 == limit[0]
            else:
                ok = dist[w] == dist[u] + 1 and dfs(w)
            if ok:
                pair_u[u] = v
                pair_v[v] = u
                return True

        # Dead end for this phase: never revisit u until the next BFS.
        dist[u] = INF
        return False

    matching = 0

    while bfs():
        for u in range(n):
            if pair_u[u] == -1 and dfs(u):
                matching += 1

    return matching, pair_u, pair_v

# Demo: 4x4 bipartite adjacency matrix (rows = left vertices, columns =
# right vertices); pairs[v] is the left vertex matched to right vertex v.
adj = [
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 0, 1, 1],
[0, 0, 0, 1]
]

matching, pairs = hungarian_algorithm(adj, 4, 4)
print(f"最大匹配数: {matching}")
print(f"匹配对: {[(pairs[v], v) for v in range(4) if pairs[v] != -1]}")

七十八、字符串高级算法

后缀自动机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
class SuffixAutomaton:
    """Suffix automaton of a string, built one character at a time.

    ``self.states`` is a list of dicts with:

    * ``'len'``: length of the longest string in the state's class;
    * ``'link'``: suffix link (-1 for the initial state 0);
    * ``'next'``: outgoing transitions, character -> state index.

    ``self.last`` is the state of the whole string read so far.
    """

    def __init__(self, s):
        self.states = [{'link': -1, 'len': 0, 'next': {}}]
        self.last = 0

        for c in s:
            self._extend(c)

    def _extend(self, c):
        """Append character ``c``: add the state for the new whole string and fix suffix links."""
        p = self.last
        curr = len(self.states)
        self.states.append({'link': -1, 'len': self.states[p]['len'] + 1, 'next': {}})

        # Walk suffix links from the previous last state, pointing every state
        # that lacks a c-transition at curr.
        while p != -1 and c not in self.states[p]['next']:
            self.states[p]['next'][c] = curr
            p = self.states[p]['link']

        if p == -1:
            # Fell off the initial state: curr's suffix link is the root.
            self.states[curr]['link'] = 0
        else:
            q = self.states[p]['next'][c]

            if self.states[p]['len'] + 1 == self.states[q]['len']:
                # q's longest string is exactly p's longest string + c: link directly.
                self.states[curr]['link'] = q
            else:
                # Split q: the clone copies q's transitions and link but has
                # length len(p) + 1; q and curr then both link to the clone.
                clone = len(self.states)
                self.states.append({
                    'link': self.states[q]['link'],
                    'len': self.states[p]['len'] + 1,
                    'next': self.states[q]['next'].copy()
                })

                # Redirect c-transitions into q along p's suffix-link chain to the clone.
                while p != -1 and self.states[p]['next'].get(c) == q:
                    self.states[p]['next'][c] = clone
                    p = self.states[p]['link']

                self.states[q]['link'] = clone
                self.states[curr]['link'] = clone

        self.last = curr

    def count_distinct_substrings(self):
        """Return the number of distinct non-empty substrings.

        Each non-initial state contributes ``len - len(link)`` strings.
        """
        result = 0
        for i in range(1, len(self.states)):
            result += self.states[i]['len'] - self.states[self.states[i]['link']]['len']
        return result

    def longest_common_substring(self, t):
        """Return the length of the longest common substring of the built string and ``t``."""
        # v: current state; l: length of the current match ending at this character of t.
        v = 0
        l = 0
        best = 0

        for c in t:
            # Shorten the match via suffix links until c can extend it (or we reach the root).
            while v != 0 and c not in self.states[v]['next']:
                v = self.states[v]['link']
                l = self.states[v]['len']

            if c in self.states[v]['next']:
                v = self.states[v]['next'][c]
                l += 1

            best = max(best, l)

        return best

# Demo: "ababa" has 9 distinct non-empty substrings; 'bab' occurs in it,
# so the longest common substring length is 3.
sam = SuffixAutomaton("ababa")
print(f"不同子串数量: {sam.count_distinct_substrings()}")
print(f"与'bab'的最长公共子串: {sam.longest_common_substring('bab')}")

后缀数组

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def build_suffix_array(s):
    """Return the suffix array of ``s``: suffix start indices in sorted order.

    Uses prefix doubling. After the round with step ``k``, the ranks order
    suffixes by their first ``2k`` characters, and a position past the end
    ranks as -1 (a shorter suffix sorts first). Each round uses Python's
    ``sort``, so the total cost is O(n log^2 n). It stops once all ranks
    are distinct.

    The previous version raised IndexError for the empty string.
    """
    n = len(s)
    if n == 0:
        return []

    k = 1
    rank = [ord(c) for c in s]
    tmp = [0] * n
    sa = list(range(n))

    def key(i):
        # (rank of first k chars, rank of the next k chars or -1 past the end)
        return (rank[i], rank[i + k] if i + k < n else -1)

    while True:
        sa.sort(key=key)

        # Re-rank: consecutive suffixes share a rank iff their keys are equal.
        tmp[sa[0]] = 0
        for i in range(1, n):
            prev, curr = sa[i - 1], sa[i]
            tmp[curr] = tmp[prev] + (key(prev) != key(curr))

        rank = tmp[:]

        if rank[sa[-1]] == n - 1:
            break

        k *= 2

    return sa

def build_lcp_array(s, sa):
    """Kasai's algorithm: longest common prefixes of neighbouring suffixes.

    ``lcp[i]`` is the LCP length of the suffixes starting at ``sa[i]`` and
    ``sa[i + 1]``, so the result has ``len(s) - 1`` entries (none for an
    empty string). Suffixes are visited in text order, so the running match
    length ``h`` drops by at most one between steps.
    """
    n = len(s)
    position = [0] * n
    for i in range(n):
        position[sa[i]] = i

    lcp = [0] * (n - 1)
    h = 0
    for start in range(n):
        r = position[start]
        if r == 0:
            # The smallest suffix has no predecessor to compare with.
            h = 0
            continue

        other = sa[r - 1]
        while start + h < n and other + h < n and s[start + h] == s[other + h]:
            h += 1

        lcp[r - 1] = h
        if h:
            h -= 1

    return lcp

def count_distinct_substrings_sa(s):
    """Count the distinct non-empty substrings of ``s`` using a suffix array.

    Every substring is a prefix of some suffix. There are n(n+1)/2 suffix
    prefixes in total, and neighbouring suffixes in sorted order share
    exactly ``lcp[i]`` of them, so the answer is n(n+1)/2 - sum(lcp).
    """
    length = len(s)
    order = build_suffix_array(s)
    duplicates = sum(build_lcp_array(s, order))
    return length * (length + 1) // 2 - duplicates

# Demo: suffix array, LCP array and distinct-substring count (15) for "banana".
s = "banana"
sa = build_suffix_array(s)
lcp = build_lcp_array(s, sa)

print(f"后缀数组: {sa}")
print(f"LCP数组: {lcp}")
print(f"不同子串数量: {count_distinct_substrings_sa(s)}")

回文自动机

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
class PalindromicTree:
    """Palindromic tree (eertree) of ``s``.

    Node 0 is the imaginary root of length -1, and odd palindromes hang
    below it. Node 1 is the empty-string root of length 0, for even
    palindromes. Every other node is one distinct palindromic substring; the
    edge labelled ``c`` from node P leads to the palindrome ``c + P + c``.
    ``link`` points to the longest proper palindromic suffix, and ``suff[i]``
    is the node of the longest palindrome ending at position ``i``.
    """

    def __init__(self, s):
        self.s = s
        self.n = len(s)

        self.nodes = [{'len': -1, 'link': 0, 'next': {}}, {'len': 0, 'link': 0, 'next': {}}]
        self.num = 2
        self.last = 1

        self.suff = [0] * (self.n + 1)

        for i in range(self.n):
            self._add_char(i)

    def _get_link(self, node, pos):
        """Follow suffix links from ``node`` until ``s[pos]`` can wrap it on both sides.

        Node 0 (length -1) always matches, which ends the loop.
        """
        while True:
            cur_len = self.nodes[node]['len']
            if pos - 1 - cur_len >= 0 and self.s[pos] == self.s[pos - 1 - cur_len]:
                return node
            node = self.nodes[node]['link']

    def _add_char(self, pos):
        """Extend the tree with ``s[pos]``, creating at most one new palindrome node."""
        cur = self._get_link(self.last, pos)

        if self.s[pos] in self.nodes[cur]['next']:
            self.last = self.nodes[cur]['next'][self.s[pos]]
            self.suff[pos] = self.last
            return

        self.nodes[cur]['next'][self.s[pos]] = self.num
        self.nodes.append({'len': self.nodes[cur]['len'] + 2, 'link': 0, 'next': {}})

        if self.nodes[self.num]['len'] == 1:
            # Single characters link to the empty palindrome.
            self.nodes[self.num]['link'] = 1
        else:
            link_node = self._get_link(self.nodes[cur]['link'], pos)
            self.nodes[self.num]['link'] = self.nodes[link_node]['next'][self.s[pos]]

        self.last = self.num
        self.suff[pos] = self.num
        self.num += 1

    def count_palindromes(self):
        """Return the number of distinct non-empty palindromic substrings."""
        return self.num - 2

    def get_all_palindromes(self):
        """Return every distinct palindromic substring as ``(text, length)``.

        Results are listed in preorder: first the odd tree (under node 0),
        then the even tree (under node 1), children in insertion order.

        The previous version assembled ``path + c + reversed(path)``. That
        put the newest (outermost) character in the centre, for example
        "bab" for "aba", and produced odd lengths for even palindromes.
        """
        result = []
        # Stack entries are (node, labels, odd): labels are the edge characters
        # from the root down to node, innermost first.
        stack = [(1, "", False), (0, "", True)]
        while stack:
            node, labels, odd = stack.pop()
            if node > 1:
                outer = labels[::-1]
                # Odd: outer half + the rest without repeating the centre char.
                pal = outer + (labels[1:] if odd else labels)
                result.append((pal, self.nodes[node]['len']))
            children = list(self.nodes[node]['next'].items())
            for c, child in reversed(children):
                stack.append((child, labels + c, odd))

        return result

# Demo: "abacaba" has 7 distinct palindromic substrings
# (a, b, c, aba, aca, bacab, abacaba).
pt = PalindromicTree("abacaba")
print(f"不同回文子串数量: {pt.count_palindromes()}")

七十九、计算几何进阶

凸包算法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def convex_hull(points):
    """Andrew's monotone chain convex hull.

    Duplicates are removed and collinear boundary points are dropped
    (``<= 0`` pops them). The hull is returned counter-clockwise, starting
    from the lexicographically smallest point. Inputs with at most one
    distinct point are returned as-is (sorted, deduplicated).
    """
    pts = sorted(set(points))

    if len(pts) <= 1:
        return pts

    def turn(o, a, b):
        # z-component of (a - o) x (b - o); > 0 means a left (CCW) turn.
        return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])

    def half(seq):
        chain = []
        for p in seq:
            while len(chain) >= 2 and turn(chain[-2], chain[-1], p) <= 0:
                chain.pop()
            chain.append(p)
        return chain

    lower = half(pts)
    upper = half(reversed(pts))

    # Each chain's last point is the other chain's first; drop the duplicates.
    return lower[:-1] + upper[:-1]

def rotating_calipers(points):
    """Return ``(max_dist, p1, p2)``: the farthest pair of ``points`` (the diameter).

    Works on ``convex_hull(points)``, which is counter-clockwise with no
    collinear vertices, and uses the antipodal-pair sweep. For each hull edge
    (i, i+1), the pointer j advances while the triangle area
    |cross(hull[i], hull[i+1], hull[j])| keeps growing, i.e. while hull[j]
    moves away from that edge's line. This area is unimodal around a convex
    polygon, so the single forward-moving pointer is correct.

    The previous version advanced j by comparing the distances from hull[i]
    to hull[j] and hull[j+1]. Vertex-to-vertex distance is not unimodal
    around a convex polygon, so that loop could stop at a local maximum and
    miss the true diameter.

    Returns ``(0, None, None)`` for no points and ``(0, p, p)`` for a single
    distinct point.
    """
    hull = convex_hull(points)
    n = len(hull)

    if n == 0:
        return 0, None, None
    if n == 1:
        return 0, hull[0], hull[0]

    def dist(p1, p2):
        return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5

    if n == 2:
        return dist(hull[0], hull[1]), hull[0], hull[1]

    def area2(o, a, b):
        # Twice the triangle area: distance of b from line o-a, scaled by |a - o|.
        return abs((a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0]))

    max_dist = 0
    p1_max, p2_max = None, None

    j = 1
    for i in range(n):
        ni = (i + 1) % n
        while area2(hull[i], hull[ni], hull[(j + 1) % n]) > area2(hull[i], hull[ni], hull[j]):
            j = (j + 1) % n

        # hull[j] is antipodal to edge (i, ni): check it against both endpoints.
        for a in (hull[i], hull[ni]):
            d = dist(a, hull[j])
            if d > max_dist:
                max_dist = d
                p1_max, p2_max = a, hull[j]

    return max_dist, p1_max, p2_max

def polygon_area(points):
    """Area of a simple polygon by the shoelace formula.

    ``points`` are the vertices in order (clockwise or counter-clockwise);
    only the first two coordinates of each are used. Returns 0.0 for no points.
    """
    count = len(points)
    acc = 0
    for i, p in enumerate(points):
        q = points[(i + 1) % count]
        acc += p[0] * q[1]
        acc -= q[0] * p[1]
    return abs(acc) / 2

# Demo: convex hull of six points, the hull's area (shoelace formula), and the
# farthest pair of the same point set.
points = [(0, 0), (1, 1), (2, 2), (3, 1), (2, 0), (1, -1)]
hull = convex_hull(points)
print(f"凸包顶点: {hull}")
print(f"凸包面积: {polygon_area(hull)}")

max_dist, p1, p2 = rotating_calipers(points)
print(f"最远点对距离: {max_dist:.2f}, 点对: {p1}, {p2}")

半平面交

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class Line:
    """Directed line through ``p1`` and ``p2`` (2-D points given as indexable pairs)."""

    def __init__(self, p1, p2):
        self.p1, self.p2 = p1, p2

    def direction(self):
        """Return the direction vector ``p2 - p1`` as a tuple."""
        start, end = self.p1, self.p2
        return (end[0] - start[0], end[1] - start[1])

def line_intersection(l1, l2):
    """Intersection point of the two (infinite) lines, or None if they are parallel.

    Solves ``l1.p1 + t * d1 = l2.p1 + s * d2`` for ``t`` with 2-D cross
    products. Near-zero determinants (< 1e-10) are treated as parallel.
    """
    ax, ay = l1.direction()
    bx, by = l2.direction()

    det = ax * by - ay * bx
    if abs(det) < 1e-10:
        return None

    ox = l2.p1[0] - l1.p1[0]
    oy = l2.p1[1] - l1.p1[1]
    t = (ox * by - oy * bx) / det

    return (l1.p1[0] + t * ax, l1.p1[1] + t * ay)

def half_plane_intersection(lines):
    """Intersect the half-planes lying to the LEFT of each directed line.

    Each element of ``lines`` only needs ``p1``/``p2`` attributes (e.g.
    ``Line``). The kept side of a line is where ``cross(p2 - p1, q - p1) >= 0``.

    Uses the sort-by-angle + deque algorithm, O(m log m) for m half-planes.
    Returns the vertices of the intersection polygon in counter-clockwise
    order, or ``[]`` if the region is empty or has fewer than 3 vertices.
    The region is assumed bounded; add four bounding-box half-planes if it
    may not be. A vertex between two parallel boundary lines is skipped.

    Defects fixed from the previous version:

    * The pop test was non-standard: it checked the orientation of
      ``dq[-2].p1`` and two intersection points, instead of whether the
      deque's last vertex lies outside the new half-plane.
    * Trimming stopped at parallel lines.
    * Only the back of the deque was trimmed at the end.
    * ``lines`` was sorted in place; it is now left untouched.
    """
    import math
    from collections import deque

    eps = 1e-9

    def direction(l):
        return (l.p2[0] - l.p1[0], l.p2[1] - l.p1[1])

    def cross(a, b):
        return a[0] * b[1] - a[1] * b[0]

    def outside(l, q):
        # q lies strictly to the right of l, i.e. outside l's half-plane.
        return cross(direction(l), (q[0] - l.p1[0], q[1] - l.p1[1])) < -eps

    def meet(a, b):
        # Intersection of the two boundary lines (None if parallel).
        da, db = direction(a), direction(b)
        det = cross(da, db)
        if abs(det) < 1e-10:
            return None
        t = cross((b.p1[0] - a.p1[0], b.p1[1] - a.p1[1]), db) / det
        return (a.p1[0] + t * da[0], a.p1[1] + t * da[1])

    ordered = sorted(lines, key=lambda l: math.atan2(direction(l)[1], direction(l)[0]))

    dq = deque()
    for line in ordered:
        # Drop lines whose vertex with their deque neighbour is cut off by the new line.
        # Adjacent deque lines are never parallel (see below), so meet() is not None.
        while len(dq) >= 2 and outside(line, meet(dq[-1], dq[-2])):
            dq.pop()
        while len(dq) >= 2 and outside(line, meet(dq[0], dq[1])):
            dq.popleft()

        if dq and abs(cross(direction(line), direction(dq[-1]))) < eps:
            d_new, d_last = direction(line), direction(dq[-1])
            if d_new[0] * d_last[0] + d_new[1] * d_last[1] < 0:
                # Opposite-facing parallel neighbours: empty (or unbounded) region.
                return []
            if not outside(line, dq[-1].p1):
                # Same direction and the existing line is at least as tight.
                continue
            dq.pop()  # the new line is tighter and replaces the old one

        dq.append(line)

    # The two ends of the deque also constrain each other.
    while len(dq) >= 3 and outside(dq[0], meet(dq[-1], dq[-2])):
        dq.pop()
    while len(dq) >= 3 and outside(dq[-1], meet(dq[0], dq[1])):
        dq.popleft()

    if len(dq) < 3:
        return []

    vertices = []
    for i in range(len(dq)):
        p = meet(dq[i], dq[(i + 1) % len(dq)])
        if p is not None:
            vertices.append(p)

    return vertices

import math

圆与几何

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
def circle_intersection(c1, r1, c2, r2, eps=1e-9):
    """Intersection points of two circles.

    Args:
        c1, c2: centres as (x, y).
        r1, r2: radii.
        eps: tolerance for tangency / coincidence / separation tests
            (new keyword, default keeps calls unchanged).

    Returns:
        ``None`` if the circles coincide, ``[]`` if they do not meet,
        ``[(x, y)]`` if they are tangent, otherwise two points.

    The previous version compared floats exactly, so nearly tangent circles
    could reach ``(negative) ** 0.5``, which in Python silently yields a
    complex number and thus complex "coordinates".
    """
    dx = c2[0] - c1[0]
    dy = c2[1] - c1[1]
    d = (dx**2 + dy**2)**0.5

    if d < eps:
        # Concentric: identical circles share infinitely many points.
        return None if abs(r1 - r2) < eps else []

    if d > r1 + r2 + eps or d < abs(r1 - r2) - eps:
        return []

    # Distance from c1 along the centre line to the chord's midpoint.
    a = (r1**2 - r2**2 + d**2) / (2 * d)
    # Clamp: rounding can make r1^2 - a^2 slightly negative near tangency.
    h = max(r1**2 - a**2, 0.0)**0.5

    px = c1[0] + a * dx / d
    py = c1[1] + a * dy / d

    if abs(d - (r1 + r2)) < eps or abs(d - abs(r1 - r2)) < eps:
        return [(px, py)]

    # Offset perpendicular to the centre line.
    rx = -h * dy / d
    ry = h * dx / d

    return [(px + rx, py + ry), (px - rx, py - ry)]

def point_in_circle(p, c, r):
    """True if ``p`` lies inside or on the circle with centre ``c`` and radius ``r``.

    Compares squared distances, so no square root is needed.
    """
    ddx = p[0] - c[0]
    ddy = p[1] - c[1]
    return ddx**2 + ddy**2 <= r**2

def minimum_enclosing_circle(points, eps=1e-9):
    """Smallest circle containing all ``points`` (randomized incremental / Welzl).

    Args:
        points: iterable of (x, y) pairs.
        eps: relative tolerance for "point lies in circle" (new keyword).
            It absorbs rounding so boundary points are not judged outside.

    Returns:
        ``(center, radius)``. ``center`` is a tuple (an input point when all
        points coincide). Returns ``(None, 0)`` for empty input.

    Defects fixed from the previous version:

    * Empty input raised IndexError.
    * Containment used exact comparisons.
    * For (nearly) collinear triples, the 3-point helper returned a None
      centre, so the next containment test crashed with TypeError.
    """
    import random

    def inside(p, c, r):
        tol = eps * max(1.0, r)
        return (p[0] - c[0])**2 + (p[1] - c[1])**2 <= (r + tol)**2

    def circle_from_2_points(p1, p2):
        cx = (p1[0] + p2[0]) / 2
        cy = (p1[1] + p2[1]) / 2
        r = ((p1[0] - cx)**2 + (p1[1] - cy)**2)**0.5
        return (cx, cy), r

    def circle_from_3_points(p1, p2, p3):
        ax, ay = p1
        bx, by = p2
        cx, cy = p3

        d = 2 * (ax * (by - cy) + bx * (cy - ay) + cx * (ay - by))

        if abs(d) < 1e-10:
            # (Nearly) collinear: the circle with the farthest pair as its
            # diameter covers all three (approximately, if only nearly collinear).
            pairs = ((p1, p2), (p1, p3), (p2, p3))
            far = max(pairs, key=lambda pq: (pq[0][0] - pq[1][0])**2 + (pq[0][1] - pq[1][1])**2)
            return circle_from_2_points(*far)

        ux = ((ax**2 + ay**2) * (by - cy) + (bx**2 + by**2) * (cy - ay) + (cx**2 + cy**2) * (ay - by)) / d
        uy = ((ax**2 + ay**2) * (cx - bx) + (bx**2 + by**2) * (ax - cx) + (cx**2 + cy**2) * (bx - ax)) / d

        r = ((ux - ax)**2 + (uy - ay)**2)**0.5

        return (ux, uy), r

    pts = list(points)
    if not pts:
        return None, 0

    # Random order gives the expected-linear running time.
    random.shuffle(pts)

    c, r = pts[0], 0

    for i in range(1, len(pts)):
        if inside(pts[i], c, r):
            continue
        # pts[i] must be on the boundary of the circle for pts[:i+1].
        c, r = pts[i], 0

        for j in range(i):
            if inside(pts[j], c, r):
                continue
            # pts[i] and pts[j] are both on the boundary.
            c, r = circle_from_2_points(pts[i], pts[j])

            for k in range(j):
                if not inside(pts[k], c, r):
                    c, r = circle_from_3_points(pts[i], pts[j], pts[k])

    return c, r

# Demo: all four points lie on the circle centred at (1, 0) with radius 1.
points = [(0, 0), (1, 1), (2, 0), (1, -1)]
center, radius = minimum_enclosing_circle(points)
print(f"最小包围圆: 圆心{center}, 半径{radius:.2f}")

八十、蓝桥杯真题深度解析

2023 年真题:工作时长

题目:给定员工每天的打卡时间,计算每位员工的工作时长。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def calculate_work_time(records):
    """Sum each employee's daily (out - in) working time.

    Args:
        records: iterable of ``(emp_id, date, action, time)`` records in any
            order. ``action == 'in'`` is a clock-in; anything else is a
            clock-out. If one (employee, date) has several records of the
            same kind, the last one in input order wins.

    Returns:
        dict mapping emp_id -> total time, with employees in ascending
        emp_id order. Days missing either punch are skipped.

    Defects fixed from the previous version:

    * It sorted by (emp_id, action), which separated a day's 'in' and 'out'
      records, so multi-day employees got no hours at all.
    * It treated a punch at time 0 as missing.
    * It sorted the caller's list in place.
    """
    days = {}
    for rec in records:
        emp_id, date, action, stamp = rec[0], rec[1], rec[2], rec[3]
        punches = days.setdefault((emp_id, date), [None, None])
        punches[0 if action == 'in' else 1] = stamp

    result = {}
    # Stable sort on emp_id only: days keep input order within an employee.
    for (emp_id, _date), (in_time, out_time) in sorted(days.items(), key=lambda item: item[0][0]):
        if in_time is not None and out_time is not None:
            result[emp_id] = result.get(emp_id, 0) + (out_time - in_time)

    return result

# Demo: punch records as (emp_id, date, 'in' / 'out', time).
records = [
(1, '2023-01-01', 'in', 9),
(1, '2023-01-01', 'out', 18),
(1, '2023-01-02', 'in', 8),
(1, '2023-01-02', 'out', 17),
(2, '2023-01-01', 'in', 10),
(2, '2023-01-01', 'out', 19),
]

work_time = calculate_work_time(records)
for emp_id, hours in sorted(work_time.items()):
    print(f"员工{emp_id}: {hours}小时")

2023 年真题:数组分割

题目:将数组分成两部分,使得两部分的和的差的绝对值最小。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def min_partition_diff(arr):
    """Smallest |sum(A) - sum(B)| over splits of ``arr`` into two parts.

    0/1-knapsack over subset sums up to ``sum(arr) // 2`` (intended for
    non-negative integers). The best reachable sum ``s`` gives ``total - 2*s``.
    """
    total = sum(arr)
    half = total // 2

    reachable = [False] * (half + 1)
    reachable[0] = True

    for value in arr:
        # Downward scan so each value is used at most once.
        for s in range(half, value - 1, -1):
            if not reachable[s] and reachable[s - value]:
                reachable[s] = True

    best = next((s for s in range(half, -1, -1) if reachable[s]), None)
    return total if best is None else total - 2 * best

def partition_with_trace(arr):
    """Split ``arr`` (non-negative integers) into two parts with minimal |sum difference|.

    Returns:
        ``(diff, part1, part2)``. ``part1`` holds the chosen items as traced
        back from the end of ``arr``; ``part2`` holds every other item, in
        original order.

    The previous ``part2`` comprehension (``x not in part1 or
    part1.remove(x) or True``) was always true. It returned the whole array
    as ``part2`` and emptied ``part1`` as a side effect. The complement is
    now taken by index, which also handles duplicate values.
    """
    total = sum(arr)
    n = len(arr)
    target = total // 2

    # dp[i][j]: some subset of arr[:i] sums to exactly j.
    dp = [[False] * (target + 1) for _ in range(n + 1)]

    for i in range(n + 1):
        dp[i][0] = True

    for i in range(1, n + 1):
        w = arr[i - 1]
        for j in range(target + 1):
            dp[i][j] = dp[i - 1][j] or (j >= w and dp[i - 1][j - w])

    for j in range(target, -1, -1):
        if dp[n][j]:
            part1 = []
            taken = set()
            remaining = j
            for i in range(n, 0, -1):
                w = arr[i - 1]
                if remaining >= w and dp[i - 1][remaining - w]:
                    part1.append(w)
                    taken.add(i - 1)
                    remaining -= w

            part2 = [x for k, x in enumerate(arr) if k not in taken]

            return abs(j - (total - j)), part1, part2

    return total, arr, []

# Demo: split [1, 6, 11, 5] into two parts with minimal sum difference.
# The two parts are printed with a separator; the old f-string ran the two
# lists together ("[5, 6][1, 11]").
arr = [1, 6, 11, 5]
diff, p1, p2 = partition_with_trace(arr)
print(f"最小差值: {diff}")
print(f"分割: {p1} 与 {p2}")

2022 年真题:因数平方和

题目:计算 1 到 n 所有数的因数的平方和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
def sum_of_divisor_squares(n):
    """Return the sum, over every k in 1..n, of the squares of k's divisors.

    Each d in [1, n] divides exactly ``n // d`` of the numbers 1..n, so the
    total is ``sum(d*d * (n // d))``. This is the O(n) reference ("brute
    force") version; it returns 0 for n <= 0.

    The previous loop only visited d <= sqrt(n), multiplied correctly for
    those, but then added the paired quotient ``j = n // i`` only once as
    ``j*j``. It undercounted, e.g. 182 instead of 469 for n = 10.
    """
    return sum(d * d * (n // d) for d in range(1, n + 1))

def sum_of_divisor_squares_optimized(n):
    """Compute ``sum(d*d * (n // d) for d in 1..n)`` in O(sqrt(n)) using quotient blocks.

    ``n // d`` is constant (= q) for d in [i, j] with ``j = n // q``, so the
    block contributes ``q * (i^2 + ... + j^2)``, evaluated with the
    square-pyramidal prefix ``k(k+1)(2k+1)/6``.

    The previous version multiplied the plain sum ``i + ... + j`` by ``q*q``,
    i.e. it summed ``d * q^2`` instead of ``d^2 * q``.
    """
    def square_prefix(k):
        # 1^2 + 2^2 + ... + k^2
        return k * (k + 1) * (2 * k + 1) // 6

    result = 0
    i = 1

    while i <= n:
        q = n // i
        j = n // q

        result += q * (square_prefix(j) - square_prefix(i - 1))

        i = j + 1

    return result

# Demo: both variants for n = 10; the exact answer is sum(d*d*(10//d)) = 469.
n = 10
print(f"因数平方和 (暴力): {sum_of_divisor_squares(n)}")
print(f"因数平方和 (优化): {sum_of_divisor_squares_optimized(n)}")

2021 年真题:括号序列

题目:判断括号序列是否合法,如果不合法,最少需要添加多少个括号使其合法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
def min_add_to_make_valid(s):
    """Minimum number of '(' or ')' to insert so the parentheses in ``s`` balance.

    Other characters are ignored. Each ')' with no open '(' before it needs
    an inserted '('; each '(' left open at the end needs an inserted ')'.
    """
    open_unmatched = 0   # '(' still waiting for a ')'
    close_unmatched = 0  # ')' that arrived with nothing to close

    for ch in s:
        if ch == '(':
            open_unmatched += 1
        elif ch == ')':
            if open_unmatched:
                open_unmatched -= 1
            else:
                close_unmatched += 1

    return close_unmatched + open_unmatched

def generate_valid_parentheses_with_fix(s):
    """Enumerate bracket groupings of ``s`` by memoized interval recursion.

    For a span ``s[l..r]``:

    * if ``s[l]`` is '(' or '[', every matching closer ``s[mid]`` in the span
      pairs with it, and each candidate for the inside ``s[l+1..mid-1]`` is
      combined with each candidate for the rest ``s[mid+1..r]``;
    * if that produced nothing, every split point ``l <= mid < r``
      concatenates candidates of ``s[l..mid]`` and ``s[mid+1..r]``;
    * a lone character with no other option is kept as-is.

    NOTE(review): no character is ever inserted or removed. By induction every
    candidate for a span is the span's own text, so the result is a list of
    copies of ``s`` (one per grouping). The "with_fix" in the name is not
    reflected in the code, and the list can grow exponentially. Confirm the
    intended semantics before relying on it.
    """
    n = len(s)
    # dp[l][r] memoizes the candidate list for s[l..r].
    dp = [[None] * n for _ in range(n)]

    def solve(l, r):
        if l > r:
            return ['']
        if dp[l][r] is not None:
            return dp[l][r]

        result = []

        if s[l] in '([':
            # Pair the opener at l with each matching closer in the span.
            for mid in range(l, r + 1):
                if (s[l] == '(' and s[mid] == ')') or (s[l] == '[' and s[mid] == ']'):
                    left_parts = solve(l + 1, mid - 1)
                    right_parts = solve(mid + 1, r)

                    for lp in left_parts:
                        for rp in right_parts:
                            result.append(s[l] + lp + s[mid] + rp)

        if not result:
            # No pairing possible: try every split of the span.
            for mid in range(l, r):
                left_parts = solve(l, mid)
                right_parts = solve(mid + 1, r)

                for lp in left_parts:
                    for rp in right_parts:
                        result.append(lp + rp)

        if not result:
            # Single character that cannot be paired or split.
            result = [s[l]]

        dp[l][r] = result
        return result

    return solve(0, n - 1)

# Demo: "(()))" has one unmatched ')', so one '(' must be added.
s = "(()))"
print(f"最少添加括号数: {min_add_to_make_valid(s)}")

2020 年真题:子串分值

题目:对于字符串的每个子串,其分值为子串中不同字符的个数。求所有子串的分值之和。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
def substring_score(s):
    """Sum, over all substrings of ``s``, of the number of distinct characters (O(n^2))."""
    total = 0
    for start in range(len(s)):
        distinct = set()
        # Extend the substring s[start:end] one character at a time.
        for ch in s[start:]:
            distinct.add(ch)
            total += len(distinct)
    return total

def substring_score_optimized(s):
    """Same value as ``substring_score`` in O(n).

    Credit each distinct character of a substring to its first occurrence
    inside that substring. Position ``i`` is that first occurrence for every
    substring starting after the previous occurrence of ``s[i]`` and ending
    at or after ``i``: ``(i - prev) * (n - i)`` substrings.
    """
    n = len(s)
    last_seen = {}
    total = 0

    for i, ch in enumerate(s):
        total += (i - last_seen.get(ch, -1)) * (n - i)
        last_seen[ch] = i

    return total

# Demo: both variants should print 28 for "ababc".
s = "ababc"
print(f"子串分值 (暴力): {substring_score(s)}")
print(f"子串分值 (优化): {substring_score_optimized(s)}")

八十一、综合实战案例

案例 1:迷宫最短路径

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from collections import deque

def maze_shortest_path(maze, start, end):
    """Breadth-first shortest path on a grid where 0 is open and anything else is a wall.

    Args:
        maze: list of equal-length rows.
        start, end: (row, col) pairs; lists are accepted as well as tuples.

    Returns:
        ``(steps, path)`` where ``path`` lists (row, col) tuples from start to
        end inclusive, or ``(-1, [])`` if end is unreachable or maze is empty.

    Defects fixed from the previous version:

    * ``end`` was compared with ``(x, y) == end``, which never matches when
      ``end`` is a list, so every such call returned ``(-1, [])``.
    * An empty maze raised IndexError.
    """
    if not maze or not maze[0]:
        return -1, []

    start = tuple(start)
    end = tuple(end)

    n, m = len(maze), len(maze[0])
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    queue = deque([(start[0], start[1], 0)])
    visited = [[False] * m for _ in range(n)]
    visited[start[0]][start[1]] = True

    # parent[cell] = the cell it was first reached from.
    parent = {}

    while queue:
        x, y, dist = queue.popleft()

        if (x, y) == end:
            path = []
            curr = (x, y)
            while curr in parent:
                path.append(curr)
                curr = parent[curr]
            path.append(start)
            return dist, path[::-1]

        for dx, dy in directions:
            nx, ny = x + dx, y + dy

            if 0 <= nx < n and 0 <= ny < m and not visited[nx][ny] and maze[nx][ny] == 0:
                visited[nx][ny] = True
                parent[(nx, ny)] = (x, y)
                queue.append((nx, ny, dist + 1))

    return -1, []

# Demo: 0 = open cell, 1 = wall; shortest path from the top-left to the
# bottom-right corner.
maze = [
[0, 1, 0, 0, 0],
[0, 1, 0, 1, 0],
[0, 0, 0, 1, 0],
[0, 1, 1, 1, 0],
[0, 0, 0, 0, 0]
]

dist, path = maze_shortest_path(maze, (0, 0), (4, 4))
print(f"最短距离: {dist}")
print(f"路径: {path}")

案例 2:任务调度

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def task_scheduling_with_deadlines(tasks):
    """Greedy job sequencing: unit-time jobs with deadlines, maximizing profit.

    Args:
        tasks: sequence of ``(profit, deadline, name)``. A job with deadline
            ``d`` must take one of the time slots ``1..d``.

    Returns:
        ``(total_profit, scheduled)`` where ``scheduled`` lists
        ``(name, slot, profit)`` in acceptance order (highest profit first;
        ties by earlier deadline, then input order). Each job takes the
        latest free slot not after its deadline.

    Defects fixed from the previous version:

    * Empty input raised ValueError from ``max()``.
    * The caller's list was sorted in place. That sort only existed to break
      profit ties by deadline, which the sort key below now does directly.
    """
    if not tasks:
        return 0, []

    max_deadline = max(t[1] for t in tasks)
    slots = [False] * (max_deadline + 1)

    total_profit = 0
    scheduled = []

    for profit, deadline, name in sorted(tasks, key=lambda t: (-t[0], t[1])):
        for slot in range(deadline, 0, -1):
            if not slots[slot]:
                slots[slot] = True
                total_profit += profit
                scheduled.append((name, slot, profit))
                break

    return total_profit, scheduled

# Demo: tasks are (profit, deadline, name) tuples.
tasks = [
(100, 2, 'A'),
(19, 1, 'B'),
(27, 2, 'C'),
(25, 1, 'D'),
(15, 3, 'E')
]

profit, schedule = task_scheduling_with_deadlines(tasks)
print(f"最大收益: {profit}")
print(f"调度方案: {schedule}")

案例 3:股票买卖策略

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def stock_trading_strategy(prices, k):
    """Maximum profit with at most ``k`` buy/sell transactions, plus one optimal trade list.

    Returns ``(profit, trades)``, where each trade is
    ``(buy_day, sell_day, gain)`` in chronological order.

    * If ``k >= len(prices) // 2`` the limit cannot bind, so every
      day-over-day rise is taken as its own trade.
    * Otherwise ``best[t][d]`` is the best profit using at most ``t``
      transactions within days ``0..d``. ``hold`` carries
      ``max(best[t-1][b] - prices[b])`` over buy days ``b < d``. Trades are
      recovered by walking the table backwards.
    """
    days = len(prices)

    if days < 2 or k < 1:
        return 0, []

    if k >= days // 2:
        trades = [
            (d - 1, d, prices[d] - prices[d - 1])
            for d in range(1, days)
            if prices[d] > prices[d - 1]
        ]
        return sum(gain for _, _, gain in trades), trades

    best = [[0] * days for _ in range(k + 1)]

    for t in range(1, k + 1):
        prev_row = best[t - 1]
        row = best[t]
        hold = -prices[0]

        for d in range(1, days):
            row[d] = max(row[d - 1], prices[d] + hold)
            hold = max(hold, prev_row[d] - prices[d])

    trades = []
    t, d = k, days - 1

    while t > 0 and d > 0:
        if best[t][d] != best[t][d - 1]:
            # Day d is a sell day: find the earliest buy day explaining best[t][d].
            wanted = best[t][d] - prices[d]
            buy = next((b for b in range(d) if best[t - 1][b] - prices[b] == wanted), 0)
            trades.append((buy, d, prices[d] - prices[buy]))
            d = buy
            t -= 1
        else:
            d -= 1

    trades.reverse()
    return best[k][days - 1], trades

# Demo: at most k = 2 transactions; trades print as (buy_day, sell_day, gain).
prices = [3, 2, 6, 5, 0, 3]
k = 2
profit, transactions = stock_trading_strategy(prices, k)
print(f"最大利润: {profit}")
print(f"交易记录: {transactions}")

案例 4:文件压缩

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import heapq
from collections import Counter

class HuffmanNode:
    """Huffman tree node.

    Leaves carry ``char``; internal nodes have ``char=None`` and two
    children. Nodes order by ``freq`` only, which is all ``heapq`` needs.
    """

    def __init__(self, char=None, freq=0, left=None, right=None):
        self.char, self.freq = char, freq
        self.left, self.right = left, right

    def __lt__(self, other):
        return self.freq < other.freq

def build_huffman_tree(text):
    """Build a Huffman tree from the character frequencies of ``text``.

    Repeatedly merges the two lowest-frequency nodes; ``HuffmanNode`` orders
    by ``freq``. Returns the root, or None for empty ``text``.
    """
    heap = [HuffmanNode(char=ch, freq=count) for ch, count in Counter(text).items()]
    heapq.heapify(heap)

    while len(heap) > 1:
        lo = heapq.heappop(heap)
        hi = heapq.heappop(heap)
        heapq.heappush(heap, HuffmanNode(freq=lo.freq + hi.freq, left=lo, right=hi))

    return heap[0] if heap else None

def build_codes(root):
    """Map each leaf character to its bit string ('0' = left, '1' = right).

    A tree consisting of a single leaf gets the code '0'. Leaves are
    recorded left to right; an iterative stack replaces recursion.
    """
    codes = {}
    pending = [(root, '')]

    while pending:
        node, prefix = pending.pop()
        if node is None:
            continue
        if node.char is not None:
            codes[node.char] = prefix or '0'
            continue
        # Push right first so the left subtree is handled first.
        pending.append((node.right, prefix + '1'))
        pending.append((node.left, prefix + '0'))

    return codes

def huffman_encode(text):
    """Huffman-encode ``text``.

    Returns ``(bits, code_table, original_bits, compressed_bits)``, where
    ``bits`` is the '0'/'1' string, ``original_bits`` assumes 8 bits per
    character, and ``compressed_bits`` is ``len(bits)``.
    """
    code_table = build_codes(build_huffman_tree(text))
    bits = ''.join([code_table[ch] for ch in text])
    return bits, code_table, 8 * len(text), len(bits)

# Demo: the original size counts 8 bits per character; the printed ratio is
# the fraction of bits saved. NOTE: divides by orig, so an empty text would
# raise ZeroDivisionError.
text = "this is an example for huffman encoding"
encoded, codes, orig, comp = huffman_encode(text)

print(f"编码表: {codes}")
print(f"原始位数: {orig}")
print(f"压缩后位数: {comp}")
print(f"压缩率: {(1 - comp/orig)*100:.1f}%")

案例 5:社交网络分析

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def social_network_analysis(adj, n):
    """Basic metrics of an unweighted graph given as adjacency lists.

    Returns a dict with:

    * ``avg_distances[i]``: mean BFS distance from i to the vertices it
      reaches (0 if it reaches none);
    * ``diameter``: the largest finite shortest-path distance (0 if n == 0);
    * ``clusters``: vertex lists found by BFS from each unvisited vertex
      (connected components for symmetric adjacency lists);
    * ``degrees``: ``len(adj[i])``;
    * ``most_central``: the vertex with the smallest average distance among
      vertices that reach at least one other vertex (first on ties); 0 if
      none does, -1 for an empty graph;
    * ``most_connected``: the vertex with the largest degree (first on
      ties), or -1 if the graph is empty.

    Defects fixed from the previous version:

    * ``max()`` over no rows raised ValueError for n == 0.
    * Isolated vertices have average distance 0 only because they reach
      nobody, yet ``min()`` reported them as the most central vertex.
    """
    from collections import deque

    def bfs_distances(start):
        dist = [-1] * n
        dist[start] = 0
        queue = deque([start])

        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if dist[v] == -1:
                    dist[v] = dist[u] + 1
                    queue.append(v)

        return dist

    all_distances = [bfs_distances(i) for i in range(n)]

    avg_distances = []
    reach_counts = []
    for row in all_distances:
        valid = [d for d in row if d > 0]
        avg_distances.append(sum(valid) / len(valid) if valid else 0)
        reach_counts.append(len(valid))

    # Unreachable entries are -1, so the max only sees finite distances.
    diameter = max((max(row) for row in all_distances), default=0)

    def find_clusters():
        visited = [False] * n
        clusters = []

        for start in range(n):
            if visited[start]:
                continue

            cluster = []
            queue = deque([start])
            visited[start] = True

            while queue:
                u = queue.popleft()
                cluster.append(u)

                for v in adj[u]:
                    if not visited[v]:
                        visited[v] = True
                        queue.append(v)

            clusters.append(cluster)

        return clusters

    clusters = find_clusters()

    degrees = [len(adj[i]) for i in range(n)]

    connected = [i for i in range(n) if reach_counts[i] > 0]
    if connected:
        most_central = min(connected, key=lambda i: avg_distances[i])
    else:
        most_central = 0 if n else -1

    return {
        'avg_distances': avg_distances,
        'diameter': diameter,
        'clusters': clusters,
        'degrees': degrees,
        'most_central': most_central,
        'most_connected': degrees.index(max(degrees)) if degrees else -1
    }

# Demo: 6-person network as symmetric adjacency lists.
adj = [
[1, 2],
[0, 2, 3],
[0, 1, 3],
[1, 2, 4],
[3, 5],
[4]
]

analysis = social_network_analysis(adj, 6)
print(f"网络直径: {analysis['diameter']}")
print(f"社区数量: {len(analysis['clusters'])}")
print(f"最中心节点: {analysis['most_central']}")
print(f"最活跃节点: {analysis['most_connected']}")

八十二、算法竞赛技巧总结

常见陷阱与注意事项

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Checklist of common contest pitfalls, kept as a plain string constant.
COMMON_PITFALLS = """
1. 整数溢出:
- Python 整数无溢出,但其他语言需要注意
- 大数运算使用模运算

2. 浮点数精度:
- 使用 Decimal 处理精确计算
- 比较时使用 eps = 1e-9

3. 数组越界:
- 检查索引范围
- 使用边界条件处理

4. 递归深度:
- Python 默认递归深度约 1000
- 使用 sys.setrecursionlimit() 调整

5. 时间复杂度:
- 注意循环嵌套层数
- 预处理优化查询

6. 空间复杂度:
- 滚动数组优化
- 及时释放不用的变量

7. 输入输出:
- 大量数据使用快速 I/O
- 注意输出格式要求

8. 边界情况:
- 空数组、单元素
- 最大最小值
- 特殊字符处理
"""

def handle_large_input():
    """Read ``n`` followed by ``n`` integers from stdin with one bulk read.

    Also raises the interpreter's recursion limit to 10**6 as a side effect.
    Returns the list of integers.
    """
    import sys
    sys.setrecursionlimit(10**6)

    # One read + split is much faster than calling input() per line.
    tokens = iter(sys.stdin.read().split())
    count = int(next(tokens))
    return [int(next(tokens)) for _ in range(count)]

def precision_handling():
    """Return Decimal('0.1') + Decimal('0.2'), which is exactly Decimal('0.3').

    Side effect: sets the current decimal context precision to 50 digits.
    """
    from decimal import Decimal, getcontext
    getcontext().prec = 50

    return Decimal('0.1') + Decimal('0.2')

def modular_arithmetic():
    """Return ``(add, mul, pow_mod)`` helpers that work modulo 10**9 + 7."""
    MOD = 10**9 + 7

    def add(a, b):
        return (a + b) % MOD

    def mul(a, b):
        return (a * b) % MOD

    def pow_mod(base, exp):
        """Return ``base ** exp % MOD`` by square-and-multiply.

        Raises:
            ValueError: if ``exp`` is negative. Previously ``exp >>= 1``
                stayed at -1 forever, so the loop never terminated.
        """
        if exp < 0:
            raise ValueError("exp must be non-negative")
        result = 1
        while exp:
            if exp & 1:
                result = mul(result, base)
            base = mul(base, base)
            exp >>= 1
        return result

    return add, mul, pow_mod

代码优化技巧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def optimization_techniques():
    """Intentionally empty placeholder; returns None."""
    return None

def use_local_variables():
    """Return ``[sqrt(0), ..., sqrt(999)]``.

    Binds ``math.sqrt`` to a local name so the hot loop does one fast local
    lookup instead of a module attribute lookup per item.
    """
    import math
    sqrt = math.sqrt

    return [sqrt(i) for i in range(1000)]

def avoid_repeated_computation():
    """Return a memoized function computing ``x*x + 2*x + 1``.

    Results are cached in a dict private to the returned closure.
    """
    memo = {}

    def expensive_function(x):
        try:
            return memo[x]
        except KeyError:
            pass
        value = x * x + 2 * x + 1
        memo[x] = value
        return value

    return expensive_function

def use_builtin_functions():
    """Aggregate the list 0..999 with C-implemented builtins.

    Returns ``(sum, max, min)``. A descending ``sorted`` copy is built for
    illustration but is not returned.
    """
    values = list(range(1000))

    total, maximum, minimum = sum(values), max(values), min(values)
    descending = sorted(values, reverse=True)

    return total, maximum, minimum

def list_comprehension_optimization():
    """Return example comprehensions.

    The three results are squares of 0..99, squares of the even numbers in
    0..99, and a 10x10 multiplication table.
    """
    squares = [n ** 2 for n in range(100)]
    even_squares = [n ** 2 for n in range(0, 100, 2)]
    matrix = [[row * col for col in range(10)] for row in range(10)]

    return squares, even_squares, matrix

def generator_for_memory():
    """Return the first 20 Fibonacci numbers, produced lazily by a generator."""
    def fibonacci_generator(n):
        prev, cur = 0, 1
        for _ in range(n):
            yield prev
            prev, cur = cur, prev + cur

    return list(fibonacci_generator(20))

调试与测试方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def debug_and_test():
    """Intentionally empty placeholder; returns None."""
    return None

def simple_test(func, test_cases):
    """Run ``func`` on ``(input, expected)`` cases, print a report, and return True if all pass.

    A tuple input is unpacked as positional arguments; anything else is
    passed as the single argument. Exceptions are reported, not raised.
    """
    passed = 0
    for number, (input_data, expected) in enumerate(test_cases, 1):
        try:
            args = input_data if isinstance(input_data, tuple) else (input_data,)
            result = func(*args)

            if result == expected:
                passed += 1
                print(f"测试 {number}: 通过")
            else:
                print(f"测试 {number}: 失败")
                print(f" 输入: {input_data}")
                print(f" 期望: {expected}")
                print(f" 实际: {result}")
        except Exception as e:
            print(f"测试 {number}: 错误 - {e}")

    print(f"\n通过率: {passed}/{len(test_cases)}")
    return passed == len(test_cases)

def performance_test(func, *args):
    """Call ``func(*args)``, print the elapsed time in seconds, and return its result.

    Uses ``time.perf_counter``, a monotonic clock with the highest available
    resolution. ``time.time`` follows the system clock, which can be
    adjusted mid-measurement, and is coarse on some platforms, so it is
    unsuitable for timing intervals.
    """
    import time

    start = time.perf_counter()
    result = func(*args)
    elapsed = time.perf_counter() - start

    print(f"执行时间: {elapsed:.4f} 秒")
    return result

def memory_usage():
    """Print the size of a 1000-element list and return a size helper.

    The returned function wraps ``sys.getsizeof``, which is shallow: it does
    not include the objects a container refers to.
    """
    import sys

    def get_size(obj):
        return sys.getsizeof(obj)

    sample = [value for value in range(1000)]
    print(f"列表内存: {get_size(sample)} 字节")

    return get_size

总结

蓝桥杯 Python 高中组考试注重基础知识和编程能力。备考时需要:

  1. 扎实掌握 Python 语法和常用模块
  2. 熟练运用常见数据结构
  3. 理解并能够实现基础算法
  4. 多做真题,积累解题经验
  5. 注意代码规范和调试技巧

坚持练习,保持耐心,相信你一定能取得好成绩!