# List-comprehension examples.
squares = [x ** 2 for x in range(10)]
evens = [x for x in range(20) if x % 2 == 0]
# Create a 3x3 zero matrix; each row is its own list (unlike [[0] * 3] * 3,
# which would alias a single row three times).
matrix = [[0] * 3 for _ in range(3)]
# Flatten the 2-D matrix into a 1-D list (row-major order).
flattened = [x for row in matrix for x in row]
def quick_sort(arr):
    """Return a new list with the elements of ``arr`` in ascending order.

    Three-way quicksort around the middle element: elements are split into
    <, == and > partitions (new lists, not in place), so runs of values equal
    to the pivot are handled in a single pass. A list of length <= 1 is
    returned as-is.
    """
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)
def dfs(graph, start, visited=None):
    """Recursive depth-first traversal starting at ``start``.

    Prints each vertex (space-separated, no newline) when it is first visited
    and returns the set of vertices reachable from ``start``.

    Args:
        graph: ``graph[v]`` must iterate the neighbours of ``v``.
        start: vertex to begin from.
        visited: set shared across the recursive calls; leave as ``None``.
    """
    if visited is None:
        visited = set()  # fresh set per top-level call (no mutable default)
    visited.add(start)
    print(start, end=' ')
    for neighbor in graph[start]:
        if neighbor not in visited:
            dfs(graph, neighbor, visited)
    return visited
# Maze problem solved with DFS.
def solve_maze(maze, x, y, end_x, end_y, visited):
    """Return True if cell (end_x, end_y) is reachable from (x, y).

    ``maze`` is a rectangular grid in which 1 marks a wall; moves are
    4-directional. ``visited`` is a set of (row, col) cells already explored
    and is mutated in place. Cells are never un-marked, which is correct for
    a pure reachability question.

    Recursion depth can grow with the number of open cells, so large mazes
    may need a raised recursion limit.
    """
    # Reject out-of-bounds, wall and already-explored cells *before* the goal
    # test, so a goal cell that is itself a wall is not reported reachable.
    if x < 0 or x >= len(maze) or y < 0 or y >= len(maze[0]):
        return False
    if maze[x][y] == 1 or (x, y) in visited:
        return False
    if x == end_x and y == end_y:
        return True
    visited.add((x, y))
    for dx, dy in ((0, 1), (1, 0), (0, -1), (-1, 0)):
        if solve_maze(maze, x + dx, y + dy, end_x, end_y, visited):
            return True
    return False
def lis(nums):
    """Length of the longest strictly increasing subsequence (O(n^2) DP).

    dp[i] is the LIS length among subsequences that end at index i.
    An empty input returns 0 (``max`` of an empty list would raise).
    """
    if not nums:
        return 0
    n = len(nums)
    dp = [1] * n
    for i in range(1, n):
        for j in range(i):
            if nums[j] < nums[i]:
                dp[i] = max(dp[i], dp[j] + 1)
    return max(dp)
# O(n log n) optimized version of the LIS length.
import bisect


def lis_optimized(nums):
    """Length of the longest strictly increasing subsequence in O(n log n).

    ``tails[k]`` holds the smallest tail value seen so far of an increasing
    subsequence of length k + 1. ``bisect_left`` keeps the subsequence
    strict: an equal value replaces an existing tail instead of extending.
    """
    tails = []
    for num in nums:
        pos = bisect.bisect_left(tails, num)
        if pos == len(tails):
            tails.append(num)
        else:
            tails[pos] = num
    return len(tails)
def is_prime(n):
    """Return True if the integer ``n`` is prime (trial division, O(sqrt n))."""
    import math  # local import keeps this snippet self-contained

    if n < 2:
        return False
    if n == 2:
        return True
    if n % 2 == 0:
        return False
    # math.isqrt is exact; int(n ** 0.5) can be off by one for large n
    # because of floating-point rounding.
    for i in range(3, math.isqrt(n) + 1, 2):
        if n % i == 0:
            return False
    return True
def sieve_of_eratosthenes(n):
    """Return the ascending list of all primes <= n (O(n log log n)).

    n < 2 is answered up front: the flag table would otherwise be too short
    to index position 1.
    """
    if n < 2:
        return []
    is_prime = [True] * (n + 1)
    is_prime[0] = is_prime[1] = False
    for i in range(2, int(n ** 0.5) + 1):
        if is_prime[i]:
            # Smaller multiples of i were already crossed off by smaller primes.
            for j in range(i * i, n + 1, i):
                is_prime[j] = False
    return [i for i in range(n + 1) if is_prime[i]]
最大公约数与最小公倍数
1 2 3 4 5 6 7 8 9 10 11 12 13 14
def gcd(a, b):
    """Greatest common divisor of two integers (Euclidean algorithm).

    The result is always non-negative, and gcd(0, 0) == 0. Without the abs()
    a negative argument could make the loop end on a negative divisor.
    """
    while b:
        a, b = b, a % b
    return abs(a)
def lcm(a, b):
    """Least common multiple of two integers.

    The result is non-negative, and lcm(0, x) == 0. Returning early on a
    zero argument avoids 0 // gcd(0, 0), which is a ZeroDivisionError.
    """
    import math

    if a == 0 or b == 0:
        return 0
    return abs(a * b) // math.gcd(a, b)
# Extended Euclidean algorithm.
def extended_gcd(a, b):
    """Return (g, x, y) such that g == gcd(a, b) and a*x + b*y == g."""
    if b == 0:
        return a, 1, 0
    g, x, y = extended_gcd(b, a % b)
    # b*x + (a % b)*y == g and a % b == a - (a // b)*b  =>  regroup on a, b.
    return g, y, x - (a // b) * y
快速幂
1 2 3 4 5 6 7 8 9 10 11 12
def fast_pow(base, exp, mod=None):
    """Compute base ** exp, optionally modulo ``mod``, by binary exponentiation.

    Args:
        base: integer base.
        exp: non-negative integer exponent.
        mod: optional modulus; ``None`` (or 0) means no reduction.

    Returns:
        base ** exp, or (base ** exp) % mod when a modulus is given. The result
        is always fully reduced, so fast_pow(x, 0, 1) == 0.

    Raises:
        ValueError: if ``exp`` is negative (the loop would otherwise silently
            return 1).
    """
    if exp < 0:
        raise ValueError("exp must be non-negative")
    use_mod = bool(mod)  # a falsy mod has always meant "no modulus"
    if use_mod:
        base %= mod  # reduce first so the initial squaring stays small
        result = 1 % mod
    else:
        result = 1
    while exp > 0:
        if exp & 1:
            result *= base
            if use_mod:
                result %= mod
        base *= base
        if use_mod:
            base %= mod
        exp >>= 1
    return result
# Adjacency matrix for a graph with n vertices (0 = no edge).
n = 5
matrix = [[0] * n for _ in range(n)]
# Undirected edge 0-1: mark both directions.
matrix[0][1] = matrix[1][0] = 1
最短路径算法
Dijkstra 算法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
import heapq
def dijkstra(graph, start):
    """Single-source shortest paths for non-negative edge weights.

    Args:
        graph: dict mapping node -> iterable of (neighbor, weight) pairs.
            Nodes that appear only as neighbours need not be keys.
        start: source node.

    Returns:
        dict mapping node -> shortest distance from ``start``. Unreachable
        keys of ``graph`` map to ``inf``.
    """
    distances = {node: float('inf') for node in graph}
    distances[start] = 0
    pq = [(0, start)]
    while pq:
        dist, node = heapq.heappop(pq)
        if dist > distances[node]:
            continue  # stale heap entry: a shorter path was already settled
        for neighbor, weight in graph.get(node, ()):
            new_dist = dist + weight
            if new_dist < distances.get(neighbor, float('inf')):
                distances[neighbor] = new_dist
                heapq.heappush(pq, (new_dist, neighbor))
    return distances
Floyd-Warshall 算法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
def floyd_warshall(graph):
    """All-pairs shortest paths from an adjacency matrix (O(n^3)).

    ``graph[i][j]`` is the weight of edge i -> j; 0 means "no edge".
    Returns an n x n matrix of distances, with ``inf`` where unreachable.
    A node's distance to itself is 0 unless a negative self-loop exists.
    """
    n = len(graph)
    dist = [[float('inf')] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if graph[i][j] != 0:
                dist[i][j] = graph[i][j]
        # Set after the edge loop: a positive self-loop must not overwrite the
        # 0-cost "stay put" path.
        dist[i][i] = min(dist[i][i], 0)
    for k in range(n):
        dist_k = dist[k]
        for i in range(n):
            dist_i = dist[i]
            d_ik = dist_i[k]
            if d_ik == float('inf'):
                continue  # i cannot reach k, so k cannot improve row i
            for j in range(n):
                if d_ik + dist_k[j] < dist_i[j]:
                    dist_i[j] = d_ik + dist_k[j]
    return dist
# Longest palindromic substring (expand around center, O(n^2)).
def longest_palindrome(s):
    """Return the longest palindromic substring of ``s``.

    Ties keep the leftmost palindrome found. An empty string returns "".
    """
    def expand(left, right):
        # Grow outward while the ends match; return the palindrome found.
        while left >= 0 and right < len(s) and s[left] == s[right]:
            left -= 1
            right += 1
        return s[left + 1:right]

    result = ""
    for i in range(len(s)):
        odd = expand(i, i)        # odd length, centered on s[i]
        even = expand(i, i + 1)   # even length, centered between s[i], s[i+1]
        if len(odd) > len(result):
            result = odd
        if len(even) > len(result):
            result = even
    return result
# Clock simulation.
def clock_simulation(initial_time, minutes):
    """Return the 24-hour "HH:MM" time reached ``minutes`` after ``initial_time``.

    ``initial_time`` is an "H:M" string. ``minutes`` may be negative: Python's
    floor division and modulo wrap it backwards across midnight.
    """
    h, m = map(int, initial_time.split(':'))
    total = h * 60 + m + minutes
    new_h = (total // 60) % 24
    new_m = total % 60
    return f"{new_h:02d}:{new_m:02d}"
枚举与暴力
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Enumerate all subsets with a bitmask.
def all_subsets(arr):
    """Return all 2**len(arr) subsets of ``arr``.

    Bit j of the mask selects ``arr[j]``, so subsets come out in mask order
    0, 1, 2, ... and each keeps the original element order.
    """
    result = []
    for i in range(1 << len(arr)):
        subset = [arr[j] for j in range(len(arr)) if i & (1 << j)]
        result.append(subset)
    return result
# Enumerate intervals (brute force).
def find_max_subarray(arr):
    """Maximum sum over all non-empty contiguous subarrays, O(n^2).

    For each start i, a running sum is extended to every end j. An empty
    input returns ``float('-inf')``. (Kadane's algorithm does this in O(n);
    the brute force is kept here to illustrate interval enumeration.)
    """
    n = len(arr)
    max_sum = float('-inf')
    for i in range(n):
        current_sum = 0
        for j in range(i, n):
            current_sum += arr[j]
            max_sum = max(max_sum, current_sum)
    return max_sum
# Two sum (sort + two pointers).
def two_sum(nums, target):
    """Return [i, j] (i < j) with nums[i] + nums[j] == target, else [-1, -1].

    The indices refer to the caller's ``nums``, which is not modified.
    Sorting an index permutation, instead of ``nums`` itself, keeps the
    original positions: sorting the values would return indices into the
    sorted order and reorder the caller's list. O(n log n).
    """
    order = sorted(range(len(nums)), key=nums.__getitem__)
    left, right = 0, len(order) - 1
    while left < right:
        s = nums[order[left]] + nums[order[right]]
        if s == target:
            return sorted([order[left], order[right]])
        elif s < target:
            left += 1
        else:
            right -= 1
    return [-1, -1]
from itertools import permutations

# Use each digit 0-9 exactly once: the first four digits form a 4-digit
# number, the last six form a 6-digit number (neither may start with 0).
# Count the arrangements where the 6-digit number is a multiple of the
# 4-digit one.
# NOTE: brute force over 10! = 3,628,800 permutations; this can take several
# seconds in CPython.
count = 0
for p in permutations('0123456789'):
    if p[0] == '0' or p[4] == '0':
        continue
    four = int(''.join(p[:4]))
    six = int(''.join(p[4:]))
    if six % four == 0:
        count += 1
print(count)
def solve():
    """BFS shortest path through a grid maze read from stdin.

    Input: ``n m``, then n rows of m ints (0 = open, anything else blocked),
    then the start ``sx sy`` and the exit ``ex ey`` (0-based). Prints the
    minimum number of 4-directional steps, or -1 if the exit is unreachable.
    """
    # Local import so this snippet does not depend on an import placed
    # elsewhere in the file.
    from collections import deque

    n, m = map(int, input().split())
    maze = [list(map(int, input().split())) for _ in range(n)]
    sx, sy = map(int, input().split())
    ex, ey = map(int, input().split())
    visited = [[False] * m for _ in range(n)]
    queue = deque([(sx, sy, 0)])
    visited[sx][sy] = True
    directions = [(0, 1), (1, 0), (0, -1), (-1, 0)]
    while queue:
        x, y, dist = queue.popleft()
        if x == ex and y == ey:
            print(dist)
            return
        for dx, dy in directions:
            nx, ny = x + dx, y + dy
            # Mark on enqueue so each cell enters the queue at most once.
            if 0 <= nx < n and 0 <= ny < m and not visited[nx][ny] and maze[nx][ny] == 0:
                visited[nx][ny] = True
                queue.append((nx, ny, dist + 1))
    print(-1)


solve()
例题:数字三角形
题目:给定一个数字三角形,从顶部出发,每次只能移动到下方相邻的两个数字,求从顶到底的最大路径和。
解题思路:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Number triangle: maximum top-to-bottom path sum. Each step moves to one of
# the two adjacent numbers in the next row.
n = int(input())
triangle = [list(map(int, input().split())) for _ in range(n)]

# dp[i][j] = best path sum that ends at row i, column j.
dp = [[0] * n for _ in range(n)]
dp[0][0] = triangle[0][0]

for i in range(1, n):
    for j in range(i + 1):
        if j == 0:
            # Left edge: only reachable from directly above.
            dp[i][j] = dp[i - 1][j] + triangle[i][j]
        elif j == i:
            # Right edge: only reachable from the upper-left.
            dp[i][j] = dp[i - 1][j - 1] + triangle[i][j]
        else:
            dp[i][j] = max(dp[i - 1][j - 1], dp[i - 1][j]) + triangle[i][j]

# The answer is the best value anywhere on the bottom row.
print(max(dp[n - 1]))
class FenwickTree:
    """Binary indexed tree: point updates and prefix sums in O(log n).

    Indices are 1-based: valid positions are 1..n, and query(0) == 0.
    """

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)  # slot 0 is unused

    def update(self, i, delta):
        """Add ``delta`` to element i (1 <= i <= n).

        Raises:
            IndexError: if i <= 0. ``i & -i`` is 0 for i == 0, so the loop
                would never advance.
        """
        if i <= 0:
            raise IndexError("FenwickTree indices are 1-based")
        while i <= self.n:
            self.tree[i] += delta
            i += i & (-i)  # jump to the next node whose range covers i

    def query(self, i):
        """Return the sum of elements 1..i."""
        result = 0
        while i > 0:
            result += self.tree[i]
            i -= i & (-i)  # strip the lowest set bit
        return result

    def range_query(self, l, r):
        """Return the sum of elements l..r inclusive."""
        return self.query(r) - self.query(l - 1)
def next_greater_element(nums):
    """For each index, the first strictly greater value to its right, else -1.

    Uses a monotonic stack of indices whose values are non-increasing from
    bottom to top. Each index is pushed and popped at most once, so O(n).
    """
    n = len(nums)
    result = [-1] * n
    stack = []
    for i in range(n):
        # nums[i] is the answer for every smaller value still waiting.
        while stack and nums[stack[-1]] < nums[i]:
            result[stack.pop()] = nums[i]
        stack.append(i)
    return result
def largest_rectangle_area(heights):
    """Largest rectangle area under a histogram (monotonic stack, O(n)).

    A sentinel height 0 is appended (on a copy) to flush the stack at the end;
    index -1 on the stack marks the left boundary.
    """
    heights = heights + [0]
    stack = [-1]
    max_area = 0
    for i, h in enumerate(heights):
        while stack[-1] != -1 and heights[stack[-1]] > h:
            height = heights[stack.pop()]
            # The bar extends from just after the new stack top up to i - 1.
            width = i - stack[-1] - 1
            max_area = max(max_area, height * width)
        stack.append(i)
    return max_area
单调队列
单调队列用于滑动窗口最值问题。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
from collections import deque
def sliding_window_max(nums, k):
    """Maximum of every length-k window of ``nums`` (monotonic deque, O(n)).

    The deque holds indices with strictly decreasing values; the front is the
    current window's maximum. Returns [] when k <= 0 or k > len(nums).
    """
    if k <= 0:
        return []  # no windows; the loop below would index an empty deque
    result = []
    dq = deque()
    for i, num in enumerate(nums):
        while dq and nums[dq[-1]] < num:
            dq.pop()  # can never be a maximum while num is in the window
        dq.append(i)
        if dq[0] <= i - k:
            dq.popleft()  # front index slid out of the window
        if i >= k - 1:
            result.append(nums[dq[0]])
    return result
def traveling_salesman(cost):
    """Minimum cost of a tour that starts and ends at city 0 (Held-Karp DP).

    ``cost[u][v]`` is the cost of travelling u -> v. dp[mask][u] is the
    cheapest path that starts at 0, visits exactly the cities in ``mask`` and
    ends at u. Time is O(2^n * n^2), so this is only practical for small n.
    Returns 0 for n <= 1 (there is nothing to tour).
    """
    n = len(cost)
    if n <= 1:
        return 0  # min() over an empty range would raise ValueError
    INF = float('inf')
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0
    for mask in range(1 << n):
        if not mask & 1:
            continue  # every valid path contains the start city 0
        for u in range(n):
            cur = dp[mask][u]
            if cur == INF or not (mask & (1 << u)):
                continue
            for v in range(n):
                if mask & (1 << v):
                    continue
                new_mask = mask | (1 << v)
                if cur + cost[u][v] < dp[new_mask][v]:
                    dp[new_mask][v] = cur + cost[u][v]
    full_mask = (1 << n) - 1
    return min(dp[full_mask][i] + cost[i][0] for i in range(1, n))
def count_subsets_with_sum(nums, target):
    """Number of subsets (by index, including the empty one) summing to target.

    Brute force over all 2**n bitmasks, O(2^n * n).
    """
    n = len(nums)
    count = 0
    for mask in range(1 << n):
        total = sum(nums[i] for i in range(n) if mask & (1 << i))
        if total == target:
            count += 1
    return count
def merge_sort(arr):
    """Return a new ascending-sorted list (stable merge sort, O(n log n)).

    A list of length <= 1 is returned as-is.
    """
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:  # <= keeps equal elements in original order
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result.extend(left[i:])
    result.extend(right[j:])
    return result
def count_inversions(arr):
    """Number of pairs i < j with arr[i] > arr[j], via merge sort (O(n log n))."""
    def merge_count(left, right):
        # Merge two sorted lists; count the cross inversions.
        result = []
        count = 0
        i = j = 0
        while i < len(left) and j < len(right):
            if left[i] <= right[j]:
                result.append(left[i])
                i += 1
            else:
                result.append(right[j])
                # right[j] is smaller than every element left in ``left``.
                count += len(left) - i
                j += 1
        result.extend(left[i:])
        result.extend(right[j:])
        return result, count

    def sort_count(arr):
        # Return (sorted copy, inversions inside arr).
        if len(arr) <= 1:
            return arr, 0
        mid = len(arr) // 2
        left, lc = sort_count(arr[:mid])
        right, rc = sort_count(arr[mid:])
        merged, mc = merge_count(left, right)
        return merged, lc + rc + mc

    _, count = sort_count(arr)
    return count
def find_min_max_distance(positions, m):
    """Maximise the minimum gap when placing ``m`` items on ``positions``.

    Binary search on the answer: ``can_place(d)`` greedily places items left
    to right, keeping consecutive gaps >= d. Assumes 1 <= m <= len(positions).
    The caller's list is not modified; a sorted copy is used.
    """
    positions = sorted(positions)

    def can_place(min_dist):
        count = 1
        last = positions[0]
        for pos in positions[1:]:
            if pos - last >= min_dist:
                count += 1
                last = pos
        return count >= m

    # The lower bound is 0 (always feasible), not 1: with duplicate positions
    # the best achievable minimum gap can be 0.
    left, right = 0, positions[-1] - positions[0]
    while left < right:
        mid = (left + right + 1) // 2  # upper mid so ``left = mid`` progresses
        if can_place(mid):
            left = mid
        else:
            right = mid - 1
    return left
def minimize_max_value(nums, k, max_ops=None):
    """Smallest cap c in [min(nums), max(nums)] such that every element can be
    brought to <= c with at most ``max_ops`` operations, where one operation
    subtracts ``k`` from a single element.

    Args:
        nums: non-empty list of integers.
        k: positive amount removed per operation.
        max_ops: total operation budget. ``None`` means unlimited, in which
            case the answer is simply ``min(nums)``.
    """
    def ops_needed(max_val):
        # ceil((num - max_val) / k) operations for each element above the cap.
        return sum((num - max_val + k - 1) // k for num in nums if num > max_val)

    def can_achieve(max_val):
        # Feasibility is monotone in max_val, which is what the binary search
        # needs. Without a budget there is nothing to run out of, so every cap
        # is feasible.
        return max_ops is None or ops_needed(max_val) <= max_ops

    left, right = min(nums), max(nums)
    while left < right:
        mid = (left + right) // 2
        if can_achieve(mid):
            right = mid
        else:
            left = mid + 1
    return left
def topological_sort(n, edges):
    """Topological order of vertices 0..n-1 (Kahn's algorithm).

    ``edges`` is an iterable of (u, v) pairs meaning u must come before v.
    Returns [] if the graph has a cycle, i.e. not every vertex was output.
    """
    graph = [[] for _ in range(n)]
    in_degree = [0] * n
    for u, v in edges:
        graph[u].append(v)
        in_degree[v] += 1
    queue = deque([i for i in range(n) if in_degree[i] == 0])
    result = []
    while queue:
        node = queue.popleft()
        result.append(node)
        for neighbor in graph[node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                queue.append(neighbor)
    return result if len(result) == n else []
最小生成树
Kruskal 算法
1 2 3 4 5 6 7 8 9 10 11 12
def kruskal(n, edges):
    """Minimum spanning tree (or forest) by Kruskal's algorithm.

    ``edges`` is an iterable of (u, v, w) triples over vertices 0..n-1. The
    caller's list is not reordered; a sorted copy is used.
    Depends on a ``UnionFind(n)`` class defined elsewhere. Its ``union(u, v)``
    is expected to return a truthy value only when it merged two different
    sets (TODO confirm against that class).
    Returns (total_weight, mst_edges).
    """
    uf = UnionFind(n)
    total_weight = 0
    mst_edges = []
    for u, v, w in sorted(edges, key=lambda e: e[2]):
        if uf.union(u, v):  # skip edges that would close a cycle
            total_weight += w
            mst_edges.append((u, v, w))
    return total_weight, mst_edges
Prim 算法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
import heapq
def prim(n, graph):
    """Total weight of a minimum spanning tree (lazy Prim's algorithm).

    ``graph[u]`` lists the (neighbor, weight) pairs of vertex u, for vertices
    0..n-1. The tree grows from vertex 0, so a disconnected graph only counts
    the component containing 0. Returns 0 when n == 0.
    """
    if n == 0:
        return 0
    visited = [False] * n
    min_heap = [(0, 0)]  # (edge weight into node, node)
    total_weight = 0
    while min_heap:
        weight, node = heapq.heappop(min_heap)
        if visited[node]:
            continue  # already attached through a cheaper edge
        visited[node] = True
        total_weight += weight
        for neighbor, w in graph[node]:
            if not visited[neighbor]:
                heapq.heappush(min_heap, (w, neighbor))
    return total_weight
def euler_phi(n):
    """Euler's totient: the number of 1 <= k <= n with gcd(k, n) == 1.

    Divides out each prime factor p by trial division and applies
    phi *= (1 - 1/p) in integer form, O(sqrt n).
    """
    result = n
    p = 2
    while p * p <= n:
        if n % p == 0:
            while n % p == 0:
                n //= p
            result -= result // p
        p += 1
    if n > 1:
        # What remains is a prime factor larger than sqrt(original n).
        result -= result // n
    return result
def euler_phi_sieve(n):
    """List phi where phi[i] is Euler's totient of i for 0 <= i <= n (sieve)."""
    phi = list(range(n + 1))
    for i in range(2, n + 1):
        if phi[i] == i:  # untouched so far => i is prime
            for j in range(i, n + 1, i):
                phi[j] -= phi[j] // i
    return phi
模逆元
1 2 3 4 5 6 7 8 9 10 11 12 13 14
def mod_inverse(a, mod):
    """Modular inverse of ``a`` modulo ``mod`` via extended Euclid.

    Returns x in [0, mod) with a*x ≡ 1 (mod mod), or None when
    gcd(a, mod) != 1 (no inverse exists). Works for any modulus, not only
    primes.
    """
    def extended_gcd(a, b):
        # Return (g, x, y) with a*x + b*y == g == gcd(a, b).
        if b == 0:
            return a, 1, 0
        g, x, y = extended_gcd(b, a % b)
        return g, y, x - (a // b) * y

    g, x, _ = extended_gcd(a, mod)
    if g != 1:
        return None
    return x % mod
def mod_inverse_fermat(a, mod):
    """Modular inverse via Fermat's little theorem: a^(p-2) mod p.

    Only valid when ``mod`` is prime. Returns None when a ≡ 0 (mod mod), which
    has no inverse; pow() would otherwise return 0, which is not an inverse.
    """
    if a % mod == 0:
        return None
    return pow(a, mod - 2, mod)
中国剩余定理
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
def chinese_remainder_theorem(remainders, moduli):
    """Smallest non-negative x with x ≡ remainders[i] (mod moduli[i]) for all i.

    The moduli must be pairwise coprime. The result lies in [0, M), where M is
    the product of the moduli.

    Raises:
        ValueError: if some modulus is not coprime to the product of the others
            (no modular inverse; the construction would silently be wrong).
    """
    def extended_gcd(a, b):
        # Return (g, x, y) with a*x + b*y == g == gcd(a, b).
        if b == 0:
            return a, 1, 0
        g, x, y = extended_gcd(b, a % b)
        return g, y, x - (a // b) * y

    M = 1
    for m in moduli:
        M *= m
    result = 0
    for r, m in zip(remainders, moduli):
        Mi = M // m
        g, inv, _ = extended_gcd(Mi, m)  # inv = Mi^-1 mod m when g == 1
        if g != 1:
            raise ValueError("moduli must be pairwise coprime")
        result += r * Mi * inv
    return result % M
def matrix_multiply(A, B, mod=None):
    """Return the matrix product A @ B (lists of lists), optionally mod ``mod``.

    A is n x k and B is k x m. A falsy ``mod`` means no reduction.
    """
    n = len(A)
    m = len(B[0])
    k = len(B)
    C = [[0] * m for _ in range(n)]
    for i in range(n):
        row_a = A[i]
        for j in range(m):
            s = 0
            for p in range(k):
                s += row_a[p] * B[p][j]
            # Reducing once per cell gives the same residue as per-term reduction.
            C[i][j] = s % mod if mod else s
    return C
def matrix_pow(M, n, mod=None):
    """Return the square matrix M raised to the power n (n >= 0), optionally mod ``mod``.

    Binary exponentiation using matrix_multiply; n == 0 yields the identity.
    """
    size = len(M)
    result = [[1 if i == j else 0 for j in range(size)] for i in range(size)]
    while n > 0:
        if n % 2 == 1:
            result = matrix_multiply(result, M, mod)
        M = matrix_multiply(M, M, mod)  # rebinds the local name only
        n //= 2
    return result
def fibonacci_matrix(n, mod=None):
    """n-th Fibonacci number (F0 = 0, F1 = 1) via [[1,1],[1,0]]^(n-1), optionally mod ``mod``.

    The small cases are reduced too, so fibonacci_matrix(1, 1) == 0 is
    consistent with the general case.
    """
    if n <= 1:
        return n % mod if mod else n
    M = [[1, 1], [1, 0]]
    result = matrix_pow(M, n - 1, mod)
    return result[0][0]
# WRONG: deleting from a dict while iterating over it.
d = {'a': 1, 'b': 2, 'c': 3}
try:
    for key in d:
        if d[key] == 2:
            del d[key]  # the next iteration step raises RuntimeError
except RuntimeError as exc:
    # Caught only so this demonstration does not abort the whole script.
    print(f"RuntimeError: {exc}")
# RIGHT: collect the keys to delete first, then delete them.
d = {'a': 1, 'b': 2, 'c': 3}
to_delete = [k for k, v in d.items() if v == 2]
for key in to_delete:
    del d[key]
浮点数精度问题
1 2 3 4 5 6 7 8 9 10 11 12
# WRONG: comparing floats with ==.
a = 0.1 + 0.2
print(a == 0.3)  # False: in binary floating point 0.1 + 0.2 == 0.30000000000000004
# RIGHT: compare with a tolerance.
def float_equal(a, b, eps=1e-9):
    """Return True if |a - b| < eps (absolute tolerance).

    For values of very different magnitudes, math.isclose also offers a
    relative tolerance.
    """
    return abs(a - b) < eps
# Alternatively use Decimal for exact decimal arithmetic. Construct it from
# strings: Decimal(0.1) would inherit the binary float error.
from decimal import Decimal

a = Decimal('0.1') + Decimal('0.2')
print(a == Decimal('0.3'))  # True
递归深度问题
1 2 3 4 5 6 7
import sys

# Raise CPython's recursion limit for deep recursive DFS/DP. This does not
# enlarge the C stack, so extremely deep recursion can still crash the
# interpreter; prefer iterative code when depth can reach ~10**5 or more.
sys.setrecursionlimit(10**6)
def deep_recursion(n):
    """Return n by recursing n levels deep (demonstrates the recursion limit).

    Each call adds one stack frame, so n above sys.getrecursionlimit() raises
    RecursionError unless the limit has been raised. Non-positive n returns 0
    immediately; a negative n would otherwise never reach the base case.
    """
    if n <= 0:
        return 0
    return 1 + deep_recursion(n - 1)
边界条件处理
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
def binary_search(arr, target):
    """Index of ``target`` in ascending-sorted ``arr``, or -1 if absent."""
    left, right = 0, len(arr) - 1
    while left <= right:  # note the <=: the search interval [left, right] is closed
        # Overflow-safe midpoint form carried over from fixed-width-integer
        # languages; Python ints never overflow, so (left + right) // 2 is
        # equivalent here.
        mid = left + (right - left) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1
def safe_divide(a, b):
    """Divide a by b without raising when b == 0.

    A nonzero numerator over zero gives +inf or -inf (by the sign of a).
    0 / 0 gives NaN, since it has no meaningful infinite value.
    """
    if b == 0:
        if a == 0:
            return float('nan')
        return float('inf') if a > 0 else float('-inf')
    return a / b
十九、时间复杂度优化技巧
常见时间复杂度对照表
数据规模
可接受复杂度
适用算法
n ≤ 10
O(n!)
全排列
n ≤ 20
O(2^n)
状态压缩
n ≤ 100
O(n^3)
Floyd
n ≤ 1000
O(n^2)
朴素 DP
n ≤ 10^4
O(n√n)
优化枚举
n ≤ 10^5
O(n log n)
排序、二分
n ≤ 10^6
O(n)
线性算法
n ≤ 10^8
O(log n)
快速幂
优化技巧
1. 预处理
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Precompute factorials modulo MOD for O(1) lookup (e.g. for binomial
# coefficients). Storing exact factorials up to 10**5 is infeasible:
# 100000! alone has over 450,000 digits, and the table keeps every prefix.
MAX_N = 10**5
MOD = 10**9 + 7
factorial = [1] * (MAX_N + 1)
for i in range(1, MAX_N + 1):
    factorial[i] = factorial[i - 1] * i % MOD
# Precompute all primes <= MAX_N with the sieve of Eratosthenes.
# The flag table is deliberately not named ``is_prime``: this module also
# defines an ``is_prime(n)`` function, and a module-level list of that name
# would replace it.
is_prime_table = [True] * (MAX_N + 1)
primes = []
for i in range(2, MAX_N + 1):
    if is_prime_table[i]:
        primes.append(i)
        # Multiples below i*i were already crossed off by smaller primes.
        for j in range(i * i, MAX_N + 1, i):
            is_prime_table[j] = False
2. 空间换时间
1 2 3 4 5 6 7 8
# Use a hash map to speed up lookups (trading space for time).
def two_sum_hash(nums, target):
    """Return [i, j] (i < j) with nums[i] + nums[j] == target, else [-1, -1].

    Single pass, O(n): ``seen`` maps each value to the latest index where it
    occurred before the current position.
    """
    seen = {}
    for i, num in enumerate(nums):
        if target - num in seen:
            return [seen[target - num], i]
        seen[num] = i
    return [-1, -1]
3. 剪枝优化
1 2 3 4 5 6 7 8 9 10 11
def dfs_with_pruning(arr, target, start, current, result):
    """Collect every combination of ``arr`` elements (reuse allowed) summing to ``target``.

    Each combination uses non-decreasing indices starting at ``start``, and a
    copy of it is appended to ``result``. ``current`` is the partial
    combination being built; it is mutated and restored during the search.
    Assumes all elements are positive. Otherwise the "sum exceeds target"
    prune is invalid, and reusing a non-positive element can recurse forever.
    """
    # The sum of ``current`` is what is compared with target. Comparing the
    # list itself with an int is never equal and ``>`` raises TypeError.
    total = sum(current)
    if total == target:
        result.append(current[:])
        return
    if total > target:
        return  # prune: adding positive numbers only increases the sum
    for i in range(start, len(arr)):
        current.append(arr[i])
        dfs_with_pruning(arr, target, i, current, result)  # i: element may repeat
        current.pop()
4. 滚动数组
1 2 3 4 5 6
def knapsack_optimized(weights, values, capacity):
    """0/1 knapsack maximum value with a 1-D rolling array, O(n * capacity).

    dp[j] is the best value achievable with total weight <= j. Capacity is
    iterated downward so each item is used at most once.
    """
    dp = [0] * (capacity + 1)
    for w, v in zip(weights, values):
        for j in range(capacity, w - 1, -1):
            dp[j] = max(dp[j], dp[j - w] + v)
    return dp[capacity]
二十、模拟考试模板
标准答题模板
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
import sys
from collections import defaultdict, deque

# Fast line-based input; this rebinds the builtin ``input`` in this module.
input = sys.stdin.readline


def solve():
    """Answer template: read n and an array, compute ``result``, print it."""
    n = int(input())
    arr = list(map(int, input().split()))
    result = 0  # TODO: compute the answer from n and arr
    print(result)


if __name__ == "__main__":
    solve()
多组测试数据模板
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
import sys

input = sys.stdin.readline


def solve():
    """Handle one test case of a multi-case input.

    Returns False at the end of input (EOF or a terminating ``0``), True
    otherwise. Drive it with ``while solve(): pass``.
    """
    line = input()
    if not line:
        return False  # readline() returns '' at EOF; int('') would raise
    n = int(line)
    if n == 0:
        return False
    arr = list(map(int, input().split()))
    result = 0  # TODO: compute the answer from n and arr
    print(result)
    return True
def solve():
    """Cards problem: with 2021 cards of each digit 0-9, spell out 1, 2, 3, ...
    and print the last number that could be fully spelled.
    """
    cards = [2021] * 10
    for num in range(1, 100000):
        n = num
        while n > 0:
            digit = n % 10
            if cards[digit] == 0:
                # ``num`` cannot be completed, so num - 1 was the last one.
                print(num - 1)
                return
            cards[digit] -= 1
            n //= 10
    print(num)


solve()
2022 年真题:裁纸刀
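题目:小蓝有一个裁纸刀,每次可以将一张纸切成两份。现在有 n 张纸,问最少需要切多少刀才能得到 m 张纸?(注:这是简化版题意——每切一刀纸的张数恰好加 1,因此 m > n 时答案为 m − n,否则为 0。2022 年省赛原题是:一张纸上打印了 20 行 22 列二维码,需先裁掉四周边缘,再把二维码逐个裁开,答案为 4 + 19 + 20 × 21 = 443。)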
解题思路:
1 2 3 4 5 6
# Paper cutting (simplified): each cut turns one sheet into two, adding
# exactly one sheet, so going from n sheets to m needs max(0, m - n) cuts.
n, m = map(int, input().split())
print(max(0, m - n))
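2021 年真题:货物摆放
题目:小蓝有一个超大的仓库,可以摆放很多货物。现在有 n 箱货物,每箱都是 1×1×1 的正方体,小蓝要把它们在长、宽、高三个方向上堆成一个大长方体,即选取正整数 L、W、H 使 n = L × W × H。问 n = 2021041820210418 时共有多少种堆放方案(L、W、H 顺序不同视为不同方案,例如 n = 4 时有 6 种)。注意:下方代码并不是本题的解答,它解答的是 2019 年真题「数的分解」——把 2019 分解成 3 个各不相同的正整数之和,并要求每个正整数都不包含数字 2 和 4,交换 3 个整数的顺序视为同一种方法,问共有多少种不同的分解方法。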
def has_invalid_digit(n):
    """Return True if the decimal form of ``n`` (n >= 0) contains a 2 or a 4.

    For n == 0 the loop does not run, so it returns False.
    """
    while n > 0:
        digit = n % 10
        if digit == 2 or digit == 4:
            return True
        n //= 10
    return False
# Count the ways to write 2019 = a + b + c with 0 < a < b < c, where none of
# a, b, c contains the digit 2 or 4 (order does not matter).
count = 0
for a in range(1, 2020):
    if has_invalid_digit(a):
        continue
    for b in range(a + 1, 2020):
        c = 2019 - a - b
        if c <= b:
            break  # c only shrinks as b grows, so no later b can work
        if has_invalid_digit(b) or has_invalid_digit(c):
            continue
        count += 1

print(count)
二十二、高级 DP 专题
区间 DP
区间 DP 用于处理区间上的最优问题,通常枚举区间长度和分割点。
石子合并
题目:有 n 堆石子排成一排,每次可以将相邻的两堆合并,代价为两堆石子数之和。求将所有石子合并成一堆的最小代价。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
def stone_merge(stones):
    """Minimum total cost to merge adjacent piles into one (interval DP, O(n^3)).

    Merging two adjacent piles costs their combined size. dp[i][j] is the
    cheapest way to merge piles i..j; prefix sums give each interval's total.
    Returns 0 for an empty or single-pile input.
    """
    n = len(stones)
    if n == 0:
        return 0  # dp[0][n - 1] would index an empty table
    prefix = [0] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] + stones[i]
    dp = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            for k in range(i, j):  # last merge joins i..k with k+1..j
                cost = dp[i][k] + dp[k + 1][j] + prefix[j + 1] - prefix[i]
                dp[i][j] = min(dp[i][j], cost)
    return dp[0][n - 1]
矩阵链乘法
1 2 3 4 5 6 7 8 9 10 11 12 13
def matrix_chain_order(dims):
    """Minimum scalar multiplications to compute a matrix chain product.

    Matrix i has shape dims[i] x dims[i + 1], so there are len(dims) - 1
    matrices. Interval DP, O(n^3). Returns 0 when there are fewer than two
    matrices (including degenerate ``dims`` of length < 2).
    """
    n = len(dims) - 1
    if n < 1:
        return 0  # dp[0][n - 1] would index an empty table
    dp = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = float('inf')
            for k in range(i, j):  # split (i..k)(k+1..j)
                cost = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                dp[i][j] = min(dp[i][j], cost)
    return dp[0][n - 1]
树形 DP
树形 DP 用于处理树结构上的动态规划问题。
树的最大独立集
1 2 3 4 5 6 7 8 9 10 11 12 13 14
def max_independent_set(tree, n):
    """Size of a maximum independent set of a tree (tree DP).

    ``tree`` is an adjacency list, rooted at vertex 0 for the DFS. For each
    vertex the DFS returns (include, exclude): the best set size in its
    subtree with / without that vertex. ``n`` is accepted for interface
    compatibility but not used. The DFS is recursive, so very deep trees may
    need a higher recursion limit. Returns 0 for an empty tree.
    """
    if not tree:
        return 0

    def dfs(node, parent):
        include = 1
        exclude = 0
        for child in tree[node]:
            if child != parent:
                child_include, child_exclude = dfs(child, node)
                include += child_exclude  # node taken => children excluded
                exclude += max(child_include, child_exclude)
        return include, exclude

    return max(dfs(0, -1))
def count_digit_occurrences(n, target):
    """Count how often digit ``target`` (0-9) appears when writing 1..n in decimal.

    Digit DP over the digits of n. Numbers are written without leading zeros,
    so padding zeros are never counted; this only matters for target == 0.
    Assumes n >= 0.
    """
    from functools import lru_cache

    s = str(n)
    length = len(s)

    @lru_cache(None)
    def dfs(pos, tight, started, count):
        # tight:   prefix so far equals n's prefix (digits above s[pos] banned)
        # started: a nonzero digit has been placed (past the leading zeros)
        if pos == length:
            return count
        limit = int(s[pos]) if tight else 9
        total = 0
        for digit in range(limit + 1):
            new_started = started or digit != 0
            hit = 1 if (new_started and digit == target) else 0
            total += dfs(pos + 1, tight and digit == limit, new_started, count + hit)
        return total

    return dfs(0, True, False, 0)
概率 DP
骰子期望
1 2 3 4 5 6 7 8 9 10
def dice_expectation(n):
    """Expected number of fair six-sided die rolls until the running total reaches >= n.

    dp[i] is the expected number of rolls still needed when i points remain.
    Every roll costs 1, and a roll j >= i finishes. So
    dp[i] = 1 + (1/6) * sum(dp[max(i - j, 0)] for j in 1..6), with dp[0] = 0.
    Overshooting rolls still count as a roll: dropping them would give
    dp[1] = 1/6 instead of 1.
    """
    dp = [0.0] * (n + 1)
    for i in range(1, n + 1):
        dp[i] = 1 + sum(dp[max(i - j, 0)] for j in range(1, 7)) / 6
    return dp[n] if n > 0 else 0
二十三、字符串高级算法
Z 函数
Z 函数用于计算每个后缀与原串的最长公共前缀。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
def z_function(s):
    """Z-array of ``s``: z[i] is the length of the longest common prefix of s and s[i:].

    By this implementation's convention z[0] == len(s). Runs in O(n) using the
    rightmost match window [l, r). An empty string returns [].
    """
    n = len(s)
    if n == 0:
        return []  # z[0] = n would index an empty list
    z = [0] * n
    z[0] = n
    l, r = 0, 0
    for i in range(1, n):
        if i < r:
            z[i] = min(r - i, z[i - l])  # reuse the mirrored value inside the window
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] > r:
            l, r = i, i + z[i]
    return z
def manacher(s):
    """Manacher's algorithm: palindrome radii on the '#'-interleaved string.

    With t = '#' + '#'.join(s) + '#', p[i] is the largest r such that
    t[i - r : i + r + 1] is a palindrome. p[i] equals the length of the
    corresponding palindrome in ``s``, so max(p) is the longest palindromic
    substring length. O(n).
    """
    t = '#' + '#'.join(s) + '#'
    n = len(t)
    p = [0] * n
    center = right = 0
    for i in range(n):
        if i < right:
            p[i] = min(right - i, p[2 * center - i])  # mirror around center
        while (i - p[i] - 1 >= 0 and i + p[i] + 1 < n
               and t[i - p[i] - 1] == t[i + p[i] + 1]):
            p[i] += 1
        if i + p[i] > right:
            center, right = i, i + p[i]
    return p
def suffix_array(s):
    """Start indices of the suffixes of ``s`` in lexicographic order.

    Prefix doubling: after the round using offset k, ``rank`` orders suffixes
    by their first 2k characters. Each round sorts by the pair
    (rank[x], rank[x + k]), with -1 for "past the end" so shorter suffixes
    sort first. O(n log^2 n).
    """
    n = len(s)
    k = 1
    rank = [ord(c) for c in s] + [-1]
    sa = list(range(n))
    tmp = [0] * n

    def key(x):
        return (rank[x], rank[x + k] if x + k < n else -1)

    while k < n:
        sa.sort(key=key)
        tmp[sa[0]] = 0
        for i in range(1, n):
            # New rank increases exactly when the pair key changes.
            tmp[sa[i]] = tmp[sa[i - 1]] + (key(sa[i - 1]) < key(sa[i]))
        rank = tmp[:]
        k *= 2
    return sa
def lcp_array(s, sa):
    """Kasai's algorithm: lcp[i] is the LCP length of suffixes sa[i] and sa[i + 1].

    ``sa`` must be the suffix array of ``s``. O(n): h drops by at most 1 when
    moving from suffix i to i + 1, so it is carried over between iterations.
    """
    n = len(s)
    rank = [0] * n
    for i in range(n):
        rank[sa[i]] = i
    lcp = [0] * (n - 1) if n > 0 else []
    h = 0
    for i in range(n):
        if rank[i] > 0:
            j = sa[rank[i] - 1]  # suffix just before i in sorted order
            while i + h < n and j + h < n and s[i + h] == s[j + h]:
                h += 1
            lcp[rank[i] - 1] = h
            if h > 0:
                h -= 1
    return lcp
def polygon_area(points):
    """Area of a simple polygon by the shoelace formula.

    ``points`` lists the vertices in order (either orientation). Each vertex
    may be an object with ``.x`` / ``.y`` attributes or an ``(x, y)``
    sequence.
    """
    def xy(p):
        return (p.x, p.y) if hasattr(p, 'x') else (p[0], p[1])

    n = len(points)
    area = 0
    for i in range(n):
        x1, y1 = xy(points[i])
        x2, y2 = xy(points[(i + 1) % n])  # wrap around to close the polygon
        area += x1 * y2 - x2 * y1
    return abs(area) / 2
二十五、博弈论基础
Nim 游戏
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
def nim_game(piles):
    """Return True if the first player wins normal-play Nim (XOR of piles != 0)."""
    xor_sum = 0
    for pile in piles:
        xor_sum ^= pile
    return xor_sum != 0
def nim_move(piles):
    """A winning Nim move as (pile_index, stones_to_remove), or (-1, 0) if none.

    The move shrinks some pile to pile ^ xor_sum, which makes the XOR of all
    piles 0. No such pile exists exactly when xor_sum == 0 (a losing position).
    """
    xor_sum = 0
    for pile in piles:
        xor_sum ^= pile
    for i, pile in enumerate(piles):
        if pile ^ xor_sum < pile:
            return i, pile - (pile ^ xor_sum)
    return -1, 0
def sg_function(n, moves):
    """Sprague-Grundy values sg[0..n] for a subtraction game with the given move sizes.

    sg[i] is the mex (smallest non-negative integer not present) of the SG
    values reachable from i in one move.
    """
    sg = [0] * (n + 1)
    for i in range(1, n + 1):
        reachable = {sg[i - move] for move in moves if i >= move}
        g = 0
        while g in reachable:
            g += 1
        sg[i] = g
    return sg
def combined_sg(sg_values):
    """Return True if the sum of independent games is a first-player win.

    By the Sprague-Grundy theorem, that is when the XOR of the component SG
    values is nonzero.
    """
    result = 0
    for sg in sg_values:
        result ^= sg
    return result != 0
巴什博弈
1 2
def bash_game(n, m):
    """Bash game: take 1..m stones from a pile of n; whoever takes the last stone wins.

    Returns True if the first player wins, i.e. n is not a multiple of m + 1.
    """
    return n % (m + 1) != 0
威佐夫博弈
1 2 3 4 5 6 7 8 9 10
import math
def wythoff_game(a, b):
    """Wythoff's game: return True if the first player wins from piles (a, b).

    Losing positions are (floor(k * phi), floor(k * phi) + k) with
    phi = (1 + sqrt 5) / 2. floor(k * phi) is computed exactly as
    (k + isqrt(5 * k * k)) // 2: since sqrt(5 k^2) is irrational for k > 0,
    the floor can be taken inside. A float golden ratio loses precision for
    large piles.
    """
    if a > b:
        a, b = b, a
    k = b - a
    return a != (k + math.isqrt(5 * k * k)) // 2
def compress_coordinates(points):
    """Map each (x, y) point to (rank of x, rank of y) among the distinct values.

    The x and y axes are compressed independently to dense 0-based ranks,
    preserving order.
    """
    x_map = {x: i for i, x in enumerate(sorted({p[0] for p in points}))}
    y_map = {y: i for i, y in enumerate(sorted({p[1] for p in points}))}
    return [(x_map[p[0]], y_map[p[1]]) for p in points]
def test_with_cases(func, test_cases):
    """Run ``func`` on each (input_data, expected) pair and print PASS/FAIL lines.

    A tuple ``input_data`` is unpacked as positional arguments; any other value
    is passed as a single argument.
    """
    for i, (input_data, expected) in enumerate(test_cases):
        if isinstance(input_data, tuple):
            result = func(*input_data)
        else:
            result = func(input_data)
        status = "PASS" if result == expected else "FAIL"
        print(f"Test {i + 1}: {status} (got {result}, expected {expected})")
常见失误避免
忘记初始化:循环内变量要重置
数组越界:检查索引范围
整数溢出:Python 无此问题,但要注意中间计算
精度问题:浮点数比较用容差
边界条件:空输入、单个元素、最大最小值
二十八、代码模板速查
输入模板
1 2 3 4 5 6 7 8
import sys

# Faster line reads. The returned line keeps its trailing '\n'; int() ignores
# it, but strings need .strip().
input = sys.stdin.readline

n = int(input())                                              # one integer
a, b = map(int, input().split())                              # two integers on one line
arr = list(map(int, input().split()))                         # a list of integers
matrix = [list(map(int, input().split())) for _ in range(n)]  # n rows of integers
string = input().strip()                                      # one string line
快速排序
1 2 3 4 5 6 7 8
def quick_sort(arr):
    """Return a new ascending-sorted list (three-way quicksort, not in place).

    Partitions into <, == and > the middle element. A list of length <= 1 is
    returned as-is.
    """
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)
def dijkstra(graph, start, n):
    """Shortest distances from ``start`` over vertices 0..n-1 (non-negative weights).

    ``graph[u]`` lists the (v, w) edges out of u. Unreachable vertices get
    ``inf``.
    """
    dist = [float('inf')] * n
    dist[start] = 0
    pq = [(0, start)]
    while pq:
        d, u = heapq.heappop(pq)
        if d > dist[u]:
            continue  # stale entry
        for v, w in graph[u]:
            nd = d + w
            if nd < dist[v]:
                dist[v] = nd
                heapq.heappush(pq, (nd, v))
    return dist
BFS 模板
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
from collections import deque
def bfs(graph, start, n):
    """Breadth-first search from start; returns a visited flag for each vertex."""
    seen = [False] * n
    seen[start] = True
    frontier = deque([start])
    while frontier:
        node = frontier.popleft()
        for nxt in graph[node]:
            if seen[nxt]:
                continue
            seen[nxt] = True
            frontier.append(nxt)
    return seen
DFS 模板
1 2 3 4 5
def dfs(graph, u, visited):
    """Recursive DFS marking in visited every vertex reachable from u.

    Mutates visited in place and returns None.  Recursion depth equals the
    longest DFS path, so very deep graphs may exceed Python's recursion limit.
    """
    visited[u] = True
    # The generator re-checks visited lazily, right before each recursive call.
    pending = (v for v in graph[u] if not visited[v])
    for v in pending:
        dfs(graph, v, visited)
二分查找
1 2 3 4 5 6 7 8 9 10 11
def binary_search(arr, target):
    """Return an index of target in the ascending list arr, or -1 if absent."""
    lo, hi = 0, len(arr) - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        value = arr[mid]
        if value == target:
            return mid
        if value < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return -1
前缀和
1 2 3 4 5 6 7 8 9
def prefix_sum(arr):
    """Return p with p[0] = 0 and p[i + 1] = arr[0] + ... + arr[i]."""
    prefix = [0]
    running = 0
    for value in arr:
        running = running + value
        prefix.append(running)
    return prefix
def max_flow(capacity, source, sink, n):
    """Maximum flow from source to sink (Edmonds-Karp: BFS augmenting paths).

    Args:
        capacity: n x n matrix of capacities.  It is modified in place and
            holds the final residual capacities on return.
        source, sink: vertex indices.
        n: number of vertices.

    Returns:
        The value of the maximum flow.

    Each vertex now lists every residual neighbour once.  Previously the
    pair loop appended on both the (u, v) and (v, u) iterations, so every
    neighbour appeared twice per list and BFS rescanned the duplicates.
    """
    # Residual edges can only appear where some capacity exists in either direction.
    graph = [
        [v for v in range(n) if capacity[u][v] > 0 or capacity[v][u] > 0]
        for u in range(n)
    ]

    def bfs():
        # parent[v] == -1 means "not reached yet"; the source is its own parent.
        parent = [-1] * n
        parent[source] = source
        queue = deque([(source, float('inf'))])
        while queue:
            u, flow = queue.popleft()
            for v in graph[u]:
                if parent[v] == -1 and capacity[u][v] > 0:
                    parent[v] = u
                    new_flow = min(flow, capacity[u][v])
                    if v == sink:
                        return parent, new_flow
                    queue.append((v, new_flow))
        return parent, 0

    total_flow = 0
    while True:
        parent, flow = bfs()
        if flow == 0:
            break
        total_flow += flow
        # Push the bottleneck flow back along the path, updating residuals.
        v = sink
        while v != source:
            u = parent[v]
            capacity[u][v] -= flow
            capacity[v][u] += flow
            v = u
    return total_flow
def hungarian(graph, n_left, n_right):
    """Size of a maximum bipartite matching via Kuhn's augmenting-path DFS.

    graph[u] lists the right-side vertices adjacent to left vertex u.
    (Despite the name this is Kuhn's algorithm, not the weighted Hungarian
    method.)
    """
    owner = [-1] * n_right  # owner[v]: left vertex currently matched to v

    def try_augment(u, used):
        for v in graph[u]:
            if used[v]:
                continue
            used[v] = True
            if owner[v] == -1 or try_augment(owner[v], used):
                owner[v] = u
                return True
        return False

    matched = 0
    for u in range(n_left):
        if try_augment(u, [False] * n_right):
            matched += 1
    return matched
class SkipList:
    """Probabilistic ordered structure with insert, search and delete.

    Relies on a SkipListNode class defined elsewhere, called here as
    SkipListNode(level=L) and SkipListNode(key, L), whose instances expose
    ``key`` and a ``forward`` list indexable up to L.
    NOTE(review): node signature inferred from these calls; confirm.
    """

    def __init__(self, max_level=16, p=0.5):
        self.max_level = max_level
        self.p = p  # probability of promoting a node one more level
        self.level = 0  # highest level currently in use
        self.header = SkipListNode(level=max_level)

    def random_level(self):
        """Draw a level in [0, max_level]: keep promoting while random() < p."""
        level = 0
        while random.random() < self.p and level < self.max_level:
            level += 1
        return level

    def insert(self, key):
        """Insert key; an equal key already present is kept as a separate node."""
        update = [None] * (self.max_level + 1)
        current = self.header
        for i in range(self.level, -1, -1):
            while current.forward[i] and current.forward[i].key < key:
                current = current.forward[i]
            update[i] = current  # last node before key on level i
        new_level = self.random_level()
        if new_level > self.level:
            for i in range(self.level + 1, new_level + 1):
                update[i] = self.header
            self.level = new_level
        new_node = SkipListNode(key, new_level)
        for i in range(new_level + 1):
            new_node.forward[i] = update[i].forward[i]
            update[i].forward[i] = new_node

    def search(self, key):
        """Return True if key is present, False otherwise.

        Previously this returned None (not False) when key was larger than
        every stored key, because the trailing node reference was None.
        """
        current = self.header
        for i in range(self.level, -1, -1):
            while current.forward[i] and current.forward[i].key < key:
                current = current.forward[i]
        current = current.forward[0]
        return current is not None and current.key == key

    def delete(self, key):
        """Unlink one node holding key (if any) and lower self.level if possible."""
        update = [None] * (self.max_level + 1)
        current = self.header
        for i in range(self.level, -1, -1):
            while current.forward[i] and current.forward[i].key < key:
                current = current.forward[i]
            update[i] = current
        current = current.forward[0]
        if current and current.key == key:
            for i in range(self.level + 1):
                if update[i].forward[i] != current:
                    break  # the node does not reach this level
                update[i].forward[i] = current.forward[i]
            while self.level > 0 and not self.header.forward[self.level]:
                self.level -= 1
def decode_character(data):
    """Render a glyph bitmap as text, two bytes per row ('*' for 1, ' ' for 0).

    bytes_to_binary is defined elsewhere; presumably it returns a '0'/'1'
    string for one byte (TODO confirm).  An odd-length data raises IndexError.
    """
    rows = []
    for start in range(0, len(data), 2):
        bits = bytes_to_binary(data[start]) + bytes_to_binary(data[start + 1])
        rows.append(''.join('*' if bit == '1' else ' ' for bit in bits))
    return '\n'.join(rows)
def mobius_sieve(n):
    """Linear (Euler) sieve computing the Mobius function mu[0..n].

    mu[0] is left at 1 (unused).  Each composite is crossed out once, by
    its smallest prime factor.
    """
    mu = [1] * (n + 1)
    composite = [False] * (n + 1)
    primes = []
    for i in range(2, n + 1):
        if not composite[i]:
            primes.append(i)
            mu[i] = -1
        for p in primes:
            multiple = i * p
            if multiple > n:
                break
            composite[multiple] = True
            if i % p == 0:
                mu[multiple] = 0  # p^2 divides multiple
                break
            mu[multiple] = -mu[i]
    return mu
def sum_gcd(n):
    """Sum of gcd(i, j) over 1 <= i, j <= n via Mobius inversion.

    Equals sum_d d * #{(i, j): gcd(i, j) = d}, where the count is
    sum_k mu(k) * floor(n / (d*k))^2.  Uses mobius_sieve from this file.
    """
    mu = mobius_sieve(n)
    total = 0
    for d in range(1, n + 1):
        inner = 0
        for k in range(1, n // d + 1):
            inner += mu[k] * (n // (d * k)) ** 2
        total += d * inner
    return total
卢卡斯定理
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
def lucas(n, m, p):
    """C(n, m) mod prime p via Lucas' theorem (product over base-p digits).

    Uses comb(n, m, p) from this file for each digit pair.
    """
    if m == 0:
        return 1
    high = lucas(n // p, m // p, p)
    low = comb(n % p, m % p, p)
    return high * low % p
def comb(n, m, p):
    """Binomial coefficient C(n, m) modulo a prime p.

    Intended for the per-digit step of Lucas' theorem (0 <= n < p), and
    correct whenever min(m, n - m) < p.  The previous version built a
    factorial table of size p + 1 on every call: O(p) time and memory,
    which is infeasible for large primes such as 10**9 + 7.

    Returns 0 when m < 0 or m > n.

    Raises:
        ValueError: if min(m, n - m) >= p, where k! is not invertible mod p
            (use lucas() for such arguments).
    """
    if m < 0 or m > n:
        return 0
    k = min(m, n - m)  # C(n, m) == C(n, n - m); use the shorter product
    if k == 0:
        return 1
    if k >= p:
        raise ValueError("min(m, n - m) must be smaller than p; use lucas()")
    numerator = 1
    denominator = 1
    for i in range(k):
        numerator = numerator * ((n - i) % p) % p
        denominator = denominator * (i + 1) % p
    # p is prime and k < p, so k! is invertible mod p (Fermat's little theorem).
    return numerator * pow(denominator, p - 2, p) % p
def determinant(matrix, n, mod=None):
    """Determinant of the top-left n x n part of matrix by Gaussian elimination.

    Args:
        matrix: list of row lists; not modified (a copy is eliminated).
        n: dimension to use.
        mod: optional prime modulus.  When truthy, arithmetic is exact modulo
            mod (Fermat inverses) and the result lies in [0, mod).  Otherwise
            floating-point elimination is used.

    Returns:
        The determinant, or 0 when a zero column is found (singular).

    Fixes over the earlier version:
    - Entries are reduced modulo mod before pivot selection.  An entry that
      was a non-zero multiple of mod used to be accepted as a pivot with
      inverse 0, silently giving wrong results.
    - The float path uses partial pivoting (largest |pivot|) instead of an
      exact ``== 0`` test, which is numerically unstable.
    """
    a = [row[:] for row in matrix]
    if mod:
        a = [[x % mod for x in row] for row in a]
    det = 1
    for i in range(n):
        if mod:
            pivot = next((r for r in range(i, n) if a[r][i] != 0), None)
        else:
            pivot = max(range(i, n), key=lambda r: abs(a[r][i]))
            if a[pivot][i] == 0:
                pivot = None
        if pivot is None:
            return 0
        if pivot != i:
            a[i], a[pivot] = a[pivot], a[i]
            det = -det  # a row swap flips the sign
        if mod:
            inv_pivot = pow(a[i][i], mod - 2, mod)
        for j in range(i + 1, n):
            if a[j][i] == 0:
                continue
            if mod:
                ratio = a[j][i] * inv_pivot % mod
                for k in range(i, n):
                    a[j][k] = (a[j][k] - ratio * a[i][k]) % mod
            else:
                ratio = a[j][i] / a[i][i]
                for k in range(i, n):
                    a[j][k] -= ratio * a[i][k]
        det *= a[i][i]
        if mod:
            det %= mod
    return det
def fft(a, invert):
    """In-place iterative radix-2 Cooley-Tukey FFT.

    Args:
        a: list of numbers, transformed in place.  Its length is assumed to
            be a power of two (the bit-reversal and butterfly loops rely on it).
        invert: False for the forward transform (roots exp(+2*pi*i/len)),
            True for the inverse (negated angle, then every entry divided by len).
    """
    size = len(a)

    # Permute entries into bit-reversed index order.
    rev = 0
    for idx in range(1, size):
        bit = size >> 1
        while rev & bit:
            rev ^= bit
            bit >>= 1
        rev ^= bit
        if idx < rev:
            a[idx], a[rev] = a[rev], a[idx]

    direction = -1 if invert else 1
    span = 2
    while span <= size:
        half = span // 2
        step_root = cmath.exp(1j * (2 * cmath.pi / span * direction))
        for base in range(0, size, span):
            w = 1
            for lo in range(base, base + half):
                hi = lo + half
                even, odd = a[lo], a[hi] * w
                a[lo] = even + odd
                a[hi] = even - odd
                w *= step_root
        span *= 2

    if invert:
        for idx in range(size):
            a[idx] /= size
def multiply_polynomials(a, b):
    """Multiply integer-coefficient polynomials using fft() from this file.

    a and b are coefficient lists, lowest degree first.  Coefficients of the
    product are rounded to the nearest integer, so the result is exact only
    for integer inputs whose products stay within double precision.

    Returns [] if either input is empty.  Previously, e.g., [] * [1, 2]
    produced [0].
    """
    if not a or not b:
        return []
    result_len = len(a) + len(b) - 1
    size = 1
    while size < len(a) + len(b):
        size *= 2
    fa = a + [0] * (size - len(a))
    fb = b + [0] * (size - len(b))
    fft(fa, False)
    fft(fb, False)
    for i in range(size):
        fa[i] *= fb[i]  # pointwise product = convolution in coefficient space
    fft(fa, True)
    return [int(round(x.real)) for x in fa[:result_len]]
def ida_star(start, goal, heuristic, get_neighbors):
    """IDA* search: cost of a cheapest path from start to goal, or None.

    Args:
        heuristic: heuristic(node, goal), expected to be admissible.
        get_neighbors: get_neighbors(node) yields (neighbor, edge_cost) pairs.

    Each pass is a depth-first search bounded by f = g + h.  The next bound
    is the smallest f that exceeded the current one.

    Success is now reported with an explicit flag.  The old code encoded
    "found" as the negative number -g, so a goal reached at cost 0 (for
    example start == goal) returned 0, was not recognised as success, and
    the outer loop repeated forever.
    """
    def search(node, g, bound):
        # Returns (True, cost) on reaching goal, else (False, next bound).
        f = g + heuristic(node, goal)
        if f > bound:
            return False, f
        if node == goal:
            return True, g
        min_bound = float('inf')
        for neighbor, cost in get_neighbors(node):
            found, value = search(neighbor, g + cost, bound)
            if found:
                return True, value
            min_bound = min(min_bound, value)
        return False, min_bound

    bound = heuristic(start, goal)
    while True:
        found, value = search(start, 0, bound)
        if found:
            return value
        if value == float('inf'):
            return None  # nothing beyond the bound: goal unreachable
        bound = value
def iterative_deepening_dfs(graph, start, goal, max_depth):
    """Depth-limited DFS repeated with limits 0..max_depth.

    Returns the first path found as [start, ..., goal], or None.  Nodes on
    the current path are not revisited.
    """
    def dfs(node, depth, visited):
        if node == goal:
            return [node]
        if depth == 0:
            return None
        visited.add(node)
        for neighbor in graph[node]:
            if neighbor in visited:
                continue
            tail = dfs(neighbor, depth - 1, visited)
            if tail:
                return [node] + tail
        visited.remove(node)
        return None

    attempts = (dfs(start, limit, set()) for limit in range(max_depth + 1))
    return next((path for path in attempts if path), None)
class BlockArray:
    """Square-root decomposition with point assignment and range sum / max.

    Indices are 0-based and ranges are inclusive [l, r].  Blocks have size
    int(sqrt(n)) + 1 and cache their sum and maximum.
    """

    def __init__(self, arr):
        self.arr = arr[:]
        self.n = len(arr)
        size = int(self.n ** 0.5) + 1
        self.block_size = size
        self.block_count = (self.n + size - 1) // size
        self.block_max = [float('-inf')] * self.block_count
        self.block_sum = [0] * self.block_count
        for idx, value in enumerate(arr):
            b = idx // size
            if value > self.block_max[b]:
                self.block_max[b] = value
            self.block_sum[b] += value

    def update(self, idx, val):
        """Set arr[idx] = val and refresh that block's cached sum and max."""
        b = idx // self.block_size
        self.block_sum[b] += val - self.arr[idx]
        self.arr[idx] = val
        lo = b * self.block_size
        hi = min(lo + self.block_size, self.n)
        best = float('-inf')
        for i in range(lo, hi):
            if self.arr[i] > best:
                best = self.arr[i]
        self.block_max[b] = best

    def _pieces(self, l, r, per_block):
        """Yield values covering arr[l..r]: raw elements at the ragged ends and
        one cached aggregate from per_block for each fully covered block."""
        size = self.block_size
        bl, br = l // size, r // size
        if bl == br:
            for i in range(l, r + 1):
                yield self.arr[i]
            return
        for i in range(l, (bl + 1) * size):
            yield self.arr[i]
        for b in range(bl + 1, br):
            yield per_block[b]
        for i in range(br * size, r + 1):
            yield self.arr[i]

    def range_sum(self, l, r):
        """Sum of arr[l..r] inclusive."""
        total = 0
        for value in self._pieces(l, r, self.block_sum):
            total += value
        return total

    def range_max(self, l, r):
        """Maximum of arr[l..r] inclusive."""
        best = float('-inf')
        for value in self._pieces(l, r, self.block_max):
            if value > best:
                best = value
        return best
class SuffixAutomaton:
    """Suffix automaton of a string, built online one character at a time.

    Each state is a dict with 'len' (length of the longest string in the
    state), 'link' (suffix link; -1 for the initial state 0) and 'next'
    (transitions keyed by character).
    """

    def __init__(self, s):
        # State 0 is the initial state, representing the empty string.
        self.states = [{'link': -1, 'len': 0, 'next': {}}]
        self.last = 0  # state corresponding to the whole prefix read so far
        for c in s:
            self.extend(c)

    def extend(self, c):
        """Append character c to the automaton."""
        p = self.last
        curr = len(self.states)
        self.states.append({'link': 0, 'len': self.states[p]['len'] + 1, 'next': {}})
        # Walk suffix links from the old last state, adding missing c-transitions.
        while p != -1 and c not in self.states[p]['next']:
            self.states[p]['next'][c] = curr
            p = self.states[p]['link']
        if p != -1:
            q = self.states[p]['next'][c]
            if self.states[p]['len'] + 1 == self.states[q]['len']:
                # q directly extends p by c: it can be the suffix link.
                self.states[curr]['link'] = q
            else:
                # Split q: the clone copies q's transitions and link but has
                # the shorter length len(p) + 1.
                clone = len(self.states)
                self.states.append({
                    'link': self.states[q]['link'],
                    'len': self.states[p]['len'] + 1,
                    'next': self.states[q]['next'].copy()
                })
                # Redirect c-transitions that pointed to q onto the clone.
                while p != -1 and self.states[p]['next'].get(c) == q:
                    self.states[p]['next'][c] = clone
                    p = self.states[p]['link']
                self.states[q]['link'] = clone
                self.states[curr]['link'] = clone
        # When p reached -1 the new state's link stays 0 (set at creation).
        self.last = curr

    def count_substrings(self):
        """Number of distinct non-empty substrings: sum of len - len(link) over states."""
        result = 0
        for i in range(1, len(self.states)):
            result += self.states[i]['len'] - self.states[self.states[i]['link']]['len']
        return result

    def longest_common_substring(self, t):
        """Length of the longest substring of t that also occurs in the indexed string."""
        v = 0  # current state
        l = 0  # length of the current match ending at this position of t
        best = 0
        for c in t:
            # Shorten the match via suffix links until c can be appended.
            while v != 0 and c not in self.states[v]['next']:
                v = self.states[v]['link']
                l = self.states[v]['len']
            if c in self.states[v]['next']:
                v = self.states[v]['next'][c]
                l += 1
            best = max(best, l)
        return best

    def is_substring(self, t):
        """True if t occurs as a substring of the indexed string (True for t == '')."""
        v = 0
        for c in t:
            if c not in self.states[v]['next']:
                return False
            v = self.states[v]['next'][c]
        return True
class LCA:
    """Lowest common ancestor queries on a rooted tree via binary lifting.

    adj is the adjacency list of an undirected tree on vertices 0..n-1.
    Preprocessing takes O(n log n); each query takes O(log n).
    """

    def __init__(self, adj, root=0):
        self.n = len(adj)
        self.adj = adj
        self.root = root
        self.log = (self.n).bit_length()  # 2**log > n > any depth
        self.depth = [0] * self.n
        # parent[k][v] is the 2**k-th ancestor of v, or -1 above the root.
        self.parent = [[-1] * self.n for _ in range(self.log)]
        self._dfs(root, -1)
        self._build_sparse_table()

    def _dfs(self, node, par):
        """Fill depth[] and parent[0][] for the subtree hanging from node.

        Uses an explicit stack.  The previous recursive version exceeded
        Python's default recursion limit on deep trees (for example a path
        of a few thousand vertices).
        """
        stack = [(node, par)]
        while stack:
            v, p = stack.pop()
            self.parent[0][v] = p
            for child in self.adj[v]:
                if child != p:
                    self.depth[child] = self.depth[v] + 1
                    stack.append((child, v))

    def _build_sparse_table(self):
        """parent[k][v] = parent[k-1][parent[k-1][v]] for every level k."""
        for k in range(1, self.log):
            for v in range(self.n):
                if self.parent[k - 1][v] != -1:
                    self.parent[k][v] = self.parent[k - 1][self.parent[k - 1][v]]

    def lca(self, u, v):
        """Lowest common ancestor of u and v."""
        if self.depth[u] < self.depth[v]:
            u, v = v, u
        # Lift the deeper vertex to the same depth.
        diff = self.depth[u] - self.depth[v]
        for k in range(self.log):
            if diff >> k & 1:
                u = self.parent[k][u]
        if u == v:
            return u
        # Lift both while their ancestors differ; they end just below the LCA.
        for k in range(self.log - 1, -1, -1):
            if self.parent[k][u] != self.parent[k][v]:
                u = self.parent[k][u]
                v = self.parent[k][v]
        return self.parent[0][u]

    def distance(self, u, v):
        """Number of edges on the tree path between u and v."""
        return self.depth[u] + self.depth[v] - 2 * self.depth[self.lca(u, v)]

    def kth_ancestor(self, node, k):
        """k-th ancestor of node, or -1 if k exceeds node's depth."""
        if k > self.depth[node]:
            return -1
        for i in range(self.log):
            if k >> i & 1:
                node = self.parent[i][node]
        return node
def virtual_tree(adj, key_nodes):
    """Build the virtual (auxiliary) tree of key_nodes in a tree rooted at 0.

    The virtual tree keeps the key nodes plus the LCAs needed to preserve
    their ancestry.  Edges point from ancestor to descendant.

    Args:
        adj: adjacency list of an undirected tree on vertices 0..n-1.
        key_nodes: non-empty iterable of vertices; duplicates are ignored.

    Returns:
        (virtual_edges, root): virtual_edges[u] lists u's virtual children;
        root is the top vertex of the virtual tree.

    Duplicate key nodes are now removed first.  Previously a repeated vertex
    was pushed onto the stack twice and produced a self-loop edge u -> u.
    """
    n = len(adj)
    depth = [0] * n
    parent = [-1] * n
    dfn = [0] * n
    order = []
    # Iterative DFS from vertex 0: preorder numbers, depths and parents.
    stack = [(0, -1)]
    while stack:
        node, par = stack.pop()
        parent[node] = par
        dfn[node] = len(order)
        order.append(node)
        for child in adj[node]:
            if child != par:
                depth[child] = depth[node] + 1
                stack.append((child, node))

    # Binary lifting: up[k][v] is the 2**k-th ancestor of v (-1 if none).
    log = (n).bit_length()
    up = [[-1] * n for _ in range(log)]
    up[0] = parent[:]
    for k in range(1, log):
        for v in range(n):
            if up[k - 1][v] != -1:
                up[k][v] = up[k - 1][up[k - 1][v]]

    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        diff = depth[u] - depth[v]
        for k in range(log):
            if diff >> k & 1:
                u = up[k][u]
        if u == v:
            return u
        for k in range(log - 1, -1, -1):
            if up[k][u] != up[k][v]:
                u = up[k][u]
                v = up[k][v]
        return parent[u]

    key_nodes = sorted(set(key_nodes), key=lambda x: dfn[x])
    # The stack holds the current root-to-leaf chain of the virtual tree.
    stack = [key_nodes[0]]
    virtual_edges = [[] for _ in range(n)]
    for node in key_nodes[1:]:
        l = lca(node, stack[-1])
        # Close chain segments that lie deeper than the new LCA.
        while len(stack) >= 2 and depth[stack[-2]] >= depth[l]:
            virtual_edges[stack[-2]].append(stack[-1])
            stack.pop()
        if stack[-1] != l:
            virtual_edges[l].append(stack[-1])
            stack.pop()
        if not stack or stack[-1] != l:
            stack.append(l)
        stack.append(node)
    while len(stack) >= 2:
        virtual_edges[stack[-2]].append(stack[-1])
        stack.pop()
    return virtual_edges, stack[0]
def cdq_divide_conquer(points):
    """3-D partial-order counting with CDQ divide and conquer.

    For every point i, counts the other points j with x_j <= x_i,
    y_j <= y_i and z_j <= z_i.  Identical points count each other.

    Args:
        points: list of (x, y, z) tuples of mutually comparable numbers.

    Returns:
        List of (x, y, z, i, count) tuples, sorted by x (stable, so ties
        keep input order).  i is the point's index in the input.

    The previous version was broken in three ways.  It rebound its loop
    variable, so every query result was discarded.  It indexed the Fenwick
    tree by raw z, which loops forever for z <= 0.  It returned only the
    sorted (x, y, z, i) tuples.  The 5th field is the count the original
    code tried to append.
    """
    n = len(points)
    if n == 0:
        return []

    # Collapse identical points into one entry with a multiplicity.
    multiplicity = {}
    for x, y, z in points:
        multiplicity[(x, y, z)] = multiplicity.get((x, y, z), 0) + 1
    uniq = sorted(multiplicity)  # lexicographic (x, y, z) order
    weight = [multiplicity[t] for t in uniq]
    m = len(uniq)

    # Fenwick tree over compressed z ranks 1..size.
    z_values = sorted({t[2] for t in uniq})
    z_rank = {z: r + 1 for r, z in enumerate(z_values)}
    size = len(z_values)
    fenwick = [0] * (size + 1)

    def update(i, val):
        while i <= size:
            fenwick[i] += val
            i += i & -i

    def query(i):
        total = 0
        while i > 0:
            total += fenwick[i]
            i -= i & -i
        return total

    dominated = [0] * m  # weighted count of distinct triples below uniq[k]

    def cdq(lo, hi):
        if lo >= hi:
            return
        mid = (lo + hi) // 2
        cdq(lo, mid)
        cdq(mid + 1, hi)
        # Left elements precede right ones lexicographically, so their x is
        # <= any right x; count left elements that are also <= in y and z.
        left = sorted(range(lo, mid + 1), key=lambda k: uniq[k][1])
        right = sorted(range(mid + 1, hi + 1), key=lambda k: uniq[k][1])
        j = 0
        for k in right:
            while j < len(left) and uniq[left[j]][1] <= uniq[k][1]:
                update(z_rank[uniq[left[j]][2]], weight[left[j]])
                j += 1
            dominated[k] += query(z_rank[uniq[k][2]])
        for t in range(j):  # undo insertions so the tree is empty again
            update(z_rank[uniq[left[t]][2]], -weight[left[t]])

    cdq(0, m - 1)

    # Add the other copies of an identical point (they are <= in every axis).
    count_of = {uniq[k]: dominated[k] + weight[k] - 1 for k in range(m)}
    indexed = [(x, y, z, i) for i, (x, y, z) in enumerate(points)]
    indexed.sort(key=lambda p: p[0])
    return [p + (count_of[p[:3]],) for p in indexed]
def parallel_binary_search(n, queries, check_func):
    """Binary-search every query's answer in [0, n] in synchronised rounds.

    check_func(idx, mid) must be monotone in mid (False ... False True ...).
    answers[idx] is the smallest mid for which it is True, or -1 if none.
    Each round groups the active queries by midpoint and evaluates the
    groups in increasing midpoint order.
    """
    q = len(queries)
    low = [0] * q
    high = [n] * q
    answers = [-1] * q
    while True:
        buckets = {}
        for idx in range(q):
            if low[idx] <= high[idx]:
                buckets.setdefault((low[idx] + high[idx]) // 2, []).append(idx)
        if not buckets:
            break
        for mid in sorted(buckets):
            for idx in buckets[mid]:
                if check_func(idx, mid):
                    answers[idx] = mid
                    high[idx] = mid - 1
                else:
                    low[idx] = mid + 1
    return answers
def kth_smallest_queries(arr, queries):
    """Answer "k-th smallest value in arr[l..r]" queries (0-based, inclusive).

    For each query, parallel_binary_search (from this file) finds the
    shortest prefix of the sorted array containing at least k elements
    whose original index lies in [l, r].  The last element of that prefix
    is the answer.

    Returns:
        For each query, the k-th smallest value, or None when the range
        holds fewer than k elements.  Previously the position in the sorted
        array was returned instead of the value.
    """
    n = len(arr)
    indexed_arr = sorted((val, i) for i, val in enumerate(arr))

    def check(query_idx, mid):
        l, r, k = queries[query_idx]
        count = 0
        for i in range(mid + 1):
            if l <= indexed_arr[i][1] <= r:
                count += 1
        return count >= k

    positions = parallel_binary_search(n - 1, queries, check)
    return [indexed_arr[pos][0] if pos != -1 else None for pos in positions]
def simulated_annealing(initial_state, energy_func, neighbor_func, iterations=100000):
    """Minimise energy_func by simulated annealing with linear cooling.

    The temperature at step i is 1 - i / iterations.  Downhill moves are
    always accepted; uphill moves with probability exp(-delta / T).
    Returns the best (state, energy) pair seen.
    """
    current = initial_state
    current_e = energy_func(current)
    best, best_e = current, current_e
    for step in range(iterations):
        temperature = 1.0 - step / iterations
        candidate = neighbor_func(current)
        candidate_e = energy_func(candidate)
        delta = candidate_e - current_e
        accept = delta < 0 or random.random() < math.exp(-delta / temperature)
        if not accept:
            continue
        current, current_e = candidate, candidate_e
        if current_e < best_e:
            best, best_e = current, current_e
    return best, best_e
class RandomHash:
    """Random XOR hashing of multisets.

    Each distinct value lazily receives a random integer weight in
    [1, 10**18].  An array's hash is the XOR of its elements' weights.  It
    is therefore order-independent, and values occurring an even number of
    times cancel out.
    """

    def __init__(self, n):
        self.n = n  # stored; not used by the methods below
        self.hash_values = {}

    def get_hash(self, x):
        """Return the memoised random weight of x, drawing it on first use."""
        value = self.hash_values.get(x)
        if value is None:
            value = random.randint(1, 10**18)
            self.hash_values[x] = value
        return value

    def hash_array(self, arr):
        """XOR of the weights of all elements of arr (0 for an empty array)."""
        digest = 0
        for item in arr:
            digest ^= self.get_hash(item)
        return digest
def miller_rabin(n, k=10):
    """Miller-Rabin primality test.

    Deterministic for n < 3,317,044,064,679,887,385,961,981: the first 13
    primes (2..41) are used as bases, which no composite below that bound
    passes.  Above the bound, k additional random bases are tried (error
    probability at most 4**-k).  The previous version relied on k random
    bases only, so it was nondeterministic even for 64-bit inputs.
    """
    import random

    if n < 2:
        return False
    small_primes = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41)
    for p in small_primes:
        if n % p == 0:
            return n == p
    # n is odd and > 41 here, so every base in small_primes is < n.
    r, d = 0, n - 1
    while d % 2 == 0:
        r += 1
        d //= 2

    def is_witness(a):
        """True if base a proves n composite."""
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            return False
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                return False
        return True

    if any(is_witness(a) for a in small_primes):
        return False
    if n < 3317044064679887385961981:
        return True
    return not any(is_witness(random.randrange(2, n - 1)) for _ in range(k))
def pollard_rho(n):
    """One randomised Pollard's rho attempt to split n.

    Uses f(v) = v*v + c (mod n) with Floyd cycle detection.  Returns 2 for
    even n, a divisor d with 1 < d < n on success, or None when this
    attempt degenerates (d == n); callers may retry.
    """
    if n % 2 == 0:
        return 2
    tortoise = random.randint(2, n - 1)
    hare = tortoise
    c = random.randint(1, n - 1)

    def step(v):
        return (v * v + c) % n

    d = 1
    while d == 1:
        tortoise = step(tortoise)
        hare = step(step(hare))
        d = math.gcd(abs(tortoise - hare), n)
    return None if d == n else d
def factorize(n):
    """Prime factors of a positive integer n, with multiplicity (unordered).

    Uses miller_rabin for primality and pollard_rho to split composites,
    both from this file.

    Raises:
        ValueError: if n < 1.  Previously 0 recursed forever and negative
            numbers crashed inside pollard_rho.
    """
    if n < 1:
        raise ValueError("factorize() requires a positive integer")
    if n == 1:
        return []
    if miller_rabin(n):
        return [n]
    factor = pollard_rho(n)
    while factor is None:  # unlucky random parameters: retry
        factor = pollard_rho(n)
    return factorize(factor) + factorize(n // factor)
def evaluate_expression(s):
    """Evaluate an infix expression of non-negative integers.

    Supports + - * / and parentheses using operand/operator stacks.  '/' is
    floor division (Python //).  Spaces are skipped; unary minus is not
    supported.
    """
    priority = {'+': 1, '-': 1, '*': 2, '/': 2}
    operations = {
        '+': lambda x, y: x + y,
        '-': lambda x, y: x - y,
        '*': lambda x, y: x * y,
        '/': lambda x, y: x // y,
    }

    def precedence(op):
        return priority.get(op, 0)  # '(' and unknown symbols bind weakest

    def apply_op(a, b, op):
        fn = operations.get(op)
        return fn(a, b) if fn is not None else None

    values, ops = [], []

    def reduce_top():
        right = values.pop()
        left = values.pop()
        values.append(apply_op(left, right, ops.pop()))

    i, length = 0, len(s)
    while i < length:
        ch = s[i]
        if ch == ' ':
            i += 1
            continue
        if ch.isdigit():
            num = 0
            while i < length and s[i].isdigit():
                num = num * 10 + int(s[i])
                i += 1
            values.append(num)
            continue
        if ch == '(':
            ops.append(ch)
        elif ch == ')':
            while ops and ops[-1] != '(':
                reduce_top()
            ops.pop()  # discard the matching '('
        else:
            # Left-associative: pop operators of greater or equal precedence.
            while ops and precedence(ops[-1]) >= precedence(ch):
                reduce_top()
            ops.append(ch)
        i += 1
    while ops:
        reduce_top()
    return values[-1]
def merge(left, right):
    """Stable merge of two ascending lists; on ties the left element comes first."""
    merged = []
    i = j = 0
    n_left, n_right = len(left), len(right)
    while i < n_left and j < n_right:
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
            continue
        merged.append(right[j])
        j += 1
    merged += left[i:]
    merged += right[j:]
    return merged
动态规划
适用场景:最优子结构 + 重叠子问题
1 2 3 4 5 6 7 8 9 10
def dp_examples():
    """Placeholder for dynamic-programming examples; currently does nothing."""
def knapsack(weights, values, capacity):
    """0/1 knapsack: best total value with total weight <= capacity.

    One-dimensional DP.  Capacities are scanned downwards so each item is
    used at most once.
    """
    best = [0] * (capacity + 1)
    for idx, w_item in enumerate(weights):
        for cap in range(capacity, w_item - 1, -1):
            with_item = best[cap - w_item] + values[idx]
            if with_item > best[cap]:
                best[cap] = with_item
    return best[capacity]
搜索算法
适用场景:状态空间可遍历
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
def search_examples():
    """Placeholder for search-algorithm examples; currently does nothing."""
def bfs_shortest_path(graph, start, end):
    """Path with the fewest edges from start to end in an unweighted graph.

    Returns the node list [start, ..., end], or None if end is unreachable.
    Stores one parent pointer per discovered node and rebuilds the path at
    the end.  The old version copied a whole path list for every queued
    node, costing O(V * path length) time and memory.
    """
    from collections import deque

    root = object()  # parent marker for start; cannot collide with a node
    parent = {start: root}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if node == end:
            path = []
            while node is not root:
                path.append(node)
                node = parent[node]
            path.reverse()
            return path
        for neighbor in graph[node]:
            if neighbor not in parent:
                parent[neighbor] = node
                queue.append(neighbor)
    return None
def min_cut(n, edges, s, t):
    """Minimum s-t cut: returns (max-flow value, list of cut edges (u, v)).

    After max flow, the source side is every vertex reachable from s over
    edges with positive residual capacity; cut edges cross from that side to
    the other.  Relies on MaxFlow (defined elsewhere) exposing add_edge,
    edmonds_karp, graph[u] and capacity[(u, v)].
    """
    flow_net = MaxFlow(n)
    for u, v, cap in edges:
        flow_net.add_edge(u, v, cap)
    flow_value = flow_net.edmonds_karp(s, t)

    reachable = [False] * n
    reachable[s] = True
    queue = deque([s])
    while queue:
        u = queue.popleft()
        for v in flow_net.graph[u]:
            if reachable[v] or not flow_net.capacity[(u, v)] > 0:
                continue
            reachable[v] = True
            queue.append(v)

    cut_edges = [(u, v) for u, v, _ in edges if reachable[u] and not reachable[v]]
    return flow_value, cut_edges
class SensitiveWordFilter:
    """Mask occurrences of banned words using an Aho-Corasick automaton.

    Relies on AhoCorasick (defined elsewhere) with add_pattern(word, idx),
    build_fail() and search(text) yielding (end_position, pattern_index).
    """

    def __init__(self, words):
        self.ac = AhoCorasick()
        self.words = words
        for idx, word in enumerate(words):
            self.ac.add_pattern(word, idx)
        self.ac.build_fail()

    def filter(self, text, replace_char='*'):
        """Return text with every character of each matched word replaced."""
        hits = self.ac.search(text)
        chars = list(text)
        for end_pos, word_idx in hits:
            start_pos = end_pos - len(self.words[word_idx]) + 1
            for pos in range(start_pos, end_pos + 1):
                chars[pos] = replace_char
        return ''.join(chars)
def build_suffix_array(s):
    """Suffix array of s by prefix doubling (sort by rank pairs), O(n log^2 n).

    Returns the start indices of all suffixes of s in lexicographic order.
    """
    n = len(s)
    sa = list(range(n))
    rank = [ord(ch) for ch in s] + [-1]
    step = 1
    while step < n:
        def key(idx, rank=rank, step=step):
            # Rank pair of the first 2*step characters; -1 past the end.
            second = rank[idx + step] if idx + step < n else -1
            return rank[idx], second

        sa.sort(key=key)
        new_rank = [0] * n
        for pos in range(1, n):
            prev, cur = sa[pos - 1], sa[pos]
            new_rank[cur] = new_rank[prev] + (1 if key(prev) < key(cur) else 0)
        rank = new_rank
        step *= 2
    return sa
def build_lcp(s, sa):
    """Kasai's algorithm: lcp[i] = LCP of suffixes sa[i] and sa[i + 1], in O(n)."""
    n = len(s)
    rank = [0] * n
    for pos in range(n):
        rank[sa[pos]] = pos
    lcp = [0] * (n - 1)
    h = 0
    for i in range(n):
        r = rank[i]
        if r == 0:
            h = 0
            continue
        j = sa[r - 1]  # suffix just before suffix i in sorted order
        while i + h < n and j + h < n and s[i + h] == s[j + h]:
            h += 1
        lcp[r - 1] = h
        h = max(h - 1, 0)  # the next suffix keeps at least h - 1 matches
    return lcp
# Demo: build the suffix array and LCP array of "banana" and print them.
s = "banana"
sa = build_suffix_array(s)
lcp = build_lcp(s, sa)

print(f"字符串: {s}")
print("后缀数组:")
# One line per rank: its starting position and the suffix itself.
for i, pos in enumerate(sa):
    print(f" SA[{i}] = {pos}: {s[pos:]}")
print(f"LCP 数组: {lcp}")
后缀数组应用:最长重复子串
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
def longest_repeated_substring(s):
    """Longest substring occurring at least twice in s (occurrences may overlap).

    The answer is the largest LCP between adjacent suffixes in suffix-array
    order; the first maximum wins.  Uses build_suffix_array and build_lcp
    from this file.
    """
    if len(s) <= 1:
        return ""
    sa = build_suffix_array(s)
    lcp = build_lcp(s, sa)
    best_len, best_pos = 0, 0
    for idx, common in enumerate(lcp):
        if common > best_len:
            best_len, best_pos = common, sa[idx]
    return s[best_pos:best_pos + best_len]
def count_distinct_substrings(s):
    """Number of distinct non-empty substrings of s.

    All n(n+1)/2 substrings minus the prefixes shared by adjacent suffixes
    (the LCP array).  Uses build_suffix_array and build_lcp from this file.
    """
    n = len(s)
    lcp = build_lcp(s, build_suffix_array(s))
    total = n * (n + 1) // 2
    for overlap in lcp:
        total -= overlap
    return total
def find_pattern(sg):
    """Smallest period p of sequence sg, i.e. sg[i] == sg[i - p] for all i >= p.

    Only periods that repeat at least twice (p <= len(sg) // 2) are
    considered; returns None if there is none.  The previous range stopped
    at len(sg) // 2 - 1 and so missed, e.g., period 2 in [0, 1, 0, 1].
    """
    n = len(sg)
    for period in range(1, n // 2 + 1):
        if all(sg[i] == sg[i - period] for i in range(period, n)):
            return period
    return None
def misere_nim(piles):
    """Classify a misère Nim position (whoever takes the last stone loses).

    Rule:
    - If every pile has at most one stone, the player to move wins iff the
      number of non-empty piles is even.
    - Otherwise the player to move wins iff the nim-sum is non-zero.

    Returns "先手必胜" (first player wins) or "先手必败" (first player loses).

    The previous version tested "nim-sum == 0 -> lose" before the all-small
    case.  It reported e.g. [1, 1] as lost, although taking one pile forces
    the opponent to take the last stone.  An all-empty position is a win for
    the player to move, since the opponent took the last stone.
    """
    xor_sum = 0
    for p in piles:
        xor_sum ^= p
    if all(p <= 1 for p in piles):
        # With 0/1 piles the nim-sum is the parity of the number of ones.
        return "先手必胜" if xor_sum == 0 else "先手必败"
    return "先手必胜" if xor_sum != 0 else "先手必败"
def wythoff_game(a, b):
    """Classify Wythoff's game position (a, b) for the player to move.

    Returns "先手必败" for the losing positions (floor(k*phi), floor(k*phi) + k)
    with k = |a - b| and phi the golden ratio, and "先手必胜" otherwise.

    The check a == floor(k*phi) uses exact integer arithmetic.  The previous
    float version (int(k * phi)) lacks the precision for large k (around
    10**15 and beyond) and misclassifies positions.
    """
    if a > b:
        a, b = b, a
    k = b - a
    # a == floor(k*phi)  <=>  2a - k <= k*sqrt(5) < 2a - k + 2,
    # with both comparisons checked after squaring (integers only).
    lower = 2 * a - k
    upper = lower + 2
    five_k_sq = 5 * k * k
    if (lower < 0 or lower * lower <= five_k_sq) and upper > 0 and five_k_sq < upper * upper:
        return "先手必败"
    return "先手必胜"
def grundy_game(n):
    """Grundy value of a single heap of size n in Grundy's game.

    A move splits one heap into two non-empty heaps of different sizes.
    Values are built bottom-up in O(n^2) time with an inline mex.  The
    previous version recursed without memoisation (exponential time, deep
    recursion) and needed an external mex helper.
    """
    if n <= 2:
        return 0
    g = [0] * (n + 1)
    for size in range(3, n + 1):
        # Splits size = i + (size - i) with 1 <= i < size - i.
        reachable = {g[i] ^ g[size - i] for i in range(1, (size + 1) // 2)}
        mex = 0
        while mex in reachable:
            mex += 1
        g[size] = mex
    return g[n]
# Demo calls for the misère Nim and Wythoff game helpers.
print(f"Misère Nim [1,1,1]: {misere_nim([1,1,1])}")
print(f"Wythoff 游戏 (3,5): {wythoff_game(3, 5)}")
def iterative_deepening_dfs(start, goal, max_depth, get_neighbors, is_goal):
    """Iterative deepening DFS over an implicit graph.

    Tries depth limits 0..max_depth, each time calling dfs_with_depth (from
    this file) with a fresh visited set.  Returns the first path found, or
    None.
    """
    for limit in range(max_depth + 1):
        path = dfs_with_depth(start, goal, limit, set(), get_neighbors, is_goal)
        if path is None:
            continue
        return path
    return None
def dfs_with_depth(node, goal, depth, visited, get_neighbors, is_goal):
    """Depth-limited DFS; returns the path [node, ..., goal_node] or None.

    visited holds the nodes on the current path.  A node is added on entry
    and removed again when the search backtracks out of it without success.
    """
    if is_goal(node, goal):
        return [node]
    if depth == 0:
        return None
    visited.add(node)
    for neighbor in get_neighbors(node):
        if neighbor in visited:
            continue
        sub_path = dfs_with_depth(neighbor, goal, depth - 1, visited, get_neighbors, is_goal)
        if sub_path is not None:
            return [node, *sub_path]
    visited.discard(node)
    return None
def running_distance():
    """Total distance run during 2020.

    Each day counts 2 on a Monday or the 1st of a month, and 1 otherwise.
    """
    day = datetime(2020, 1, 1)
    last = datetime(2020, 12, 31)
    one_day = timedelta(days=1)
    total = 0
    while day <= last:
        total += 2 if (day.weekday() == 0 or day.day == 1) else 1
        day += one_day
    return total
def count_decompositions(target):
    """Count ways to write target = a + b + c with 0 < a < b < c.

    No summand may contain the digit 2 or 4 (checked with has_digit from
    this file).

    b now stops as soon as c = target - a - b would no longer exceed b.
    Previously b scanned all the way to target and discarded most candidates.
    """
    count = 0
    for a in range(1, target):
        if has_digit(a, 2) or has_digit(a, 4):
            continue
        # c > b  <=>  b < (target - a) / 2  <=>  b <= (target - a - 1) // 2
        for b in range(a + 1, (target - a + 1) // 2):
            if has_digit(b, 2) or has_digit(b, 4):
                continue
            c = target - a - b
            if has_digit(c, 2) or has_digit(c, 4):
                continue
            count += 1
    return count
def memory_optimization():
    """Demonstrate lazy generators versus eager lists for large sequences.

    Returns a generator over range(10**7) that yields one item at a time.
    The previous version also built an unused list of the same 10**7
    integers "for comparison", allocating hundreds of megabytes for nothing.
    A list comprehension over the same range would hold every element in
    memory at once.
    """
    large_gen = (i for i in range(10**7))

    def chunked_processing(data, chunk_size=10000):
        """Yield successive slices of data of length chunk_size (illustrative)."""
        for i in range(0, len(data), chunk_size):
            yield data[i:i + chunk_size]

    return large_gen
def use_slots():
    """Show __slots__: the returned Point(1, 2) stores only x and y, with no __dict__."""
    class Point:
        __slots__ = ['x', 'y']

        def __init__(self, x, y):
            self.x, self.y = x, y

    return Point(1, 2)
def solve_with_brute_force():
    """Read n and n integers, then print the sum of a_i * a_j over all i < j (O(n^2))."""
    n = int(input())
    arr = list(map(int, input().split()))
    result = sum(arr[i] * arr[j] for i in range(n) for j in range(i + 1, n))
    print(result)
def solve_with_optimized():
    """Same output as the brute force in O(n).

    sum_i a_i * (S - a_i) counts every pair i != j twice, so it is halved.
    """
    n = int(input())
    arr = list(map(int, input().split()))
    total = sum(arr)
    doubled = 0
    for value in arr:
        doubled += value * (total - value)
    print(doubled // 2)
def handle_small_cases():
    """Skeleton: read n and branch on its size; each branch is an empty placeholder."""
    n = int(input())
    if n <= 20:
        pass  # placeholder for tiny inputs
    elif n <= 1000:
        pass  # placeholder for medium inputs
    else:
        pass  # placeholder for large inputs
def advanced_calculator(expression):
    """Evaluate +, -, *, / and parentheses over non-negative integers.

    Tokens are extracted with a regex (any other character is ignored) and
    evaluated with operand/operator stacks.  '/' is true division.
    """
    import re
    tokens = re.findall(r'\d+|[+\-*/()]', expression)
    rank = {'+': 1, '-': 1, '*': 2, '/': 2}
    arithmetic = {
        '+': lambda a, b: a + b,
        '-': lambda a, b: a - b,
        '*': lambda a, b: a * b,
        '/': lambda a, b: a / b,
    }
    values, ops = [], []

    def collapse():
        b, a = values.pop(), values.pop()
        fn = arithmetic.get(ops.pop())
        values.append(fn(a, b) if fn is not None else None)

    for token in tokens:
        if token.isdigit():
            values.append(int(token))
        elif token == '(':
            ops.append(token)
        elif token == ')':
            while ops and ops[-1] != '(':
                collapse()
            ops.pop()  # drop the matching '('
        else:
            while ops and rank.get(ops[-1], 0) >= rank.get(token, 0):
                collapse()
            ops.append(token)
    while ops:
        collapse()
    return values[0]
def z_function(s):
    """Z-array of s: z[i] is the length of the longest common prefix of s and s[i:].

    By convention z[0] = len(s).  Runs in linear time using the rightmost
    match window [l, r).  Returns [] for an empty string; previously this
    raised IndexError when setting z[0].
    """
    n = len(s)
    if n == 0:
        return []
    z = [0] * n
    z[0] = n
    l, r = 0, 0
    for i in range(1, n):
        if i < r:
            z[i] = min(r - i, z[i - l])  # reuse the mirrored value inside the window
        while i + z[i] < n and s[z[i]] == s[i + z[i]]:
            z[i] += 1
        if i + z[i] > r:
            l, r = i, i + z[i]
    return z
def extended_kmp(text, pattern):
    """Start indices of every occurrence of pattern in text, via the Z-function.

    A position after the '#' separator whose Z-value reaches len(pattern)
    starts a match.  Uses z_function from this file.
    """
    m = len(pattern)
    z = z_function(pattern + '#' + text)
    offset = m + 1
    return [pos - offset for pos in range(offset, len(z)) if z[pos] >= m]
# Demo: Z-function of a sample string.
s = "aabcaabxaaz"
print(f"Z 函数: {z_function(s)}")

# Demo: every occurrence of pattern in text, found via the Z-function.
text = "abcabcabc"
pattern = "abc"
print(f"匹配位置: {extended_kmp(text, pattern)}")
defmanacher(s): t = '#' + '#'.join(s) + '#' n = len(t) p = [0] * n center, right = 0, 0 for i inrange(n): if i < right: mirror = 2 * center - i p[i] = min(right - i, p[mirror]) while i - p[i] - 1 >= 0and i + p[i] + 1 < n and t[i - p[i] - 1] == t[i + p[i] + 1]: p[i] += 1 if i + p[i] > right: center, right = i, i + p[i] return p
class StringHash:
    """Polynomial rolling hash of a string, giving O(1) substring hashes.

    hash[i] is the hash of s[:i]; power[i] is base**i modulo mod.
    """

    def __init__(self, s, base=131, mod=10**9 + 7):
        self.n = len(s)
        self.base = base
        self.mod = mod
        self.power = [1] * (self.n + 1)
        self.hash = [0] * (self.n + 1)
        for i in range(self.n):
            self.power[i + 1] = (self.power[i] * base) % mod
            self.hash[i + 1] = (self.hash[i] * base + ord(s[i])) % mod

    def get_hash(self, l, r):
        """Hash of the substring s[l:r] (half-open interval).

        Now uses self.mod.  The previous version referenced a bare ``mod``,
        which raised NameError unless a module-level ``mod`` happened to
        exist.
        """
        return (self.hash[r] - self.hash[l] * self.power[r - l] % self.mod + self.mod) % self.mod

    def get_full_hash(self):
        """Hash of the whole string."""
        return self.hash[self.n]
def find_common_substring(s1, s2, length):
    """Check whether s1 and s2 share a substring of the given length.

    Compares StringHash hashes (from this file), so a hash collision can
    give a false positive.  Returns (True, start index in s2) or (False, -1).
    """
    h1, h2 = StringHash(s1), StringHash(s2)
    seen = {h1.get_hash(i, i + length) for i in range(len(s1) - length + 1)}
    for start in range(len(s2) - length + 1):
        if h2.get_hash(start, start + length) in seen:
            return True, start
    return False, -1
def longest_common_substring(s1, s2):
    """Longest common substring by binary search on length plus rolling hashes.

    Uses find_common_substring from this file; the substring is taken from s2.
    """
    lo, hi = 0, min(len(s1), len(s2))
    best = ""
    while lo <= hi:
        mid = (lo + hi) // 2
        found, pos = find_common_substring(s1, s2, mid)
        if not found:
            hi = mid - 1
            continue
        best = s2[pos:pos + mid]
        lo = mid + 1
    return best
def minimal_representation(s):
    """Start index of the lexicographically smallest rotation of s.

    Classic O(n) two-pointer algorithm over s + s.  i and j are candidate
    starts and k is the length of their common prefix.  Returns 0 for an
    empty string; previously it returned 1, an invalid index.
    """
    n = len(s)
    if n == 0:
        return 0
    s = s + s
    i, j, k = 0, 1, 0
    while i < n and j < n and k < n:
        if s[i + k] == s[j + k]:
            k += 1
        elif s[i + k] > s[j + k]:
            # No rotation starting in i..i+k can be the minimum.
            i = i + k + 1
            if i <= j:
                i = j + 1
            k = 0
        else:
            j = j + k + 1
            if j <= i:
                j = i + 1
            k = 0
    return min(i, j)
def tarjan_scc(n, adj):
    """Strongly connected components via Tarjan's algorithm (recursive DFS).

    index[v] == 0 marks v as unvisited, so discovery indices start at 1.
    Components are appended as their roots finish, giving reverse
    topological order of the condensation.
    """
    index = [0] * n
    low = [0] * n
    on_stack = [False] * n
    stack = []
    sccs = []
    counter = 0

    def dfs(v):
        nonlocal counter
        counter += 1
        index[v] = low[v] = counter
        stack.append(v)
        on_stack[v] = True
        for u in adj[v]:
            if not index[u]:
                dfs(u)
                low[v] = min(low[v], low[u])
            elif on_stack[u]:
                low[v] = min(low[v], index[u])
        if low[v] != index[v]:
            return  # v is not the root of its component
        component = []
        while True:
            u = stack.pop()
            on_stack[u] = False
            component.append(u)
            if u == v:
                break
        sccs.append(component)

    for v in range(n):
        if not index[v]:
            dfs(v)
    return sccs
def kosaraju_scc(n, adj):
    """Strongly connected components via Kosaraju's two-pass DFS.

    Pass 1 records vertices in order of DFS finishing time.  Pass 2 runs DFS
    on the reversed graph in decreasing finishing time; each tree found is
    one SCC.
    """
    seen = [False] * n
    finish_order = []

    def forward(v):
        seen[v] = True
        for u in adj[v]:
            if not seen[u]:
                forward(u)
        finish_order.append(v)

    for v in range(n):
        if not seen[v]:
            forward(v)

    reverse_adj = [[] for _ in range(n)]
    for v in range(n):
        for u in adj[v]:
            reverse_adj[u].append(v)

    assigned = [False] * n
    sccs = []

    def backward(v, component):
        assigned[v] = True
        component.append(v)
        for u in reverse_adj[v]:
            if not assigned[u]:
                backward(u, component)

    for v in reversed(finish_order):
        if not assigned[v]:
            component = []
            backward(v, component)
            sccs.append(component)
    return sccs
def find_bridges(n, adj):
    """Bridges of an undirected graph (Tarjan low-link), as (parent, child) pairs.

    adj[v] lists the neighbours of v; parallel edges appear as repeated
    entries.  Only one edge back to the DFS parent (the tree edge itself) is
    skipped, so a doubled edge correctly counts as a back edge and is not a
    bridge.  The old code skipped every edge to the parent and reported
    parallel edges as bridges.
    """
    disc = [0] * n
    low = [0] * n
    visited = [False] * n
    bridges = []
    timer = [0]

    def dfs(v, parent):
        visited[v] = True
        disc[v] = low[v] = timer[0] + 1
        timer[0] += 1
        skipped_parent_edge = False
        for u in adj[v]:
            if u == parent and not skipped_parent_edge:
                skipped_parent_edge = True  # the tree edge we arrived by
                continue
            if visited[u]:
                low[v] = min(low[v], disc[u])
            else:
                dfs(u, v)
                low[v] = min(low[v], low[u])
                if low[u] > disc[v]:
                    bridges.append((v, u))

    for v in range(n):
        if not visited[v]:
            dfs(v, -1)
    return bridges
def find_articulation_points(n, adj):
    """Articulation points (cut vertices) of an undirected graph, ascending.

    A non-root v is a cut vertex if some DFS child u has low[u] >= disc[v].
    A DFS root is one if it has more than one DFS child.
    """
    disc = [0] * n
    low = [0] * n
    visited = [False] * n
    is_cut = [False] * n
    clock = 0

    def dfs(v, parent):
        nonlocal clock
        children = 0
        visited[v] = True
        clock += 1
        disc[v] = low[v] = clock
        for u in adj[v]:
            if u == parent:
                continue
            if visited[u]:
                low[v] = min(low[v], disc[u])
                continue
            children += 1
            dfs(u, v)
            low[v] = min(low[v], low[u])
            if parent != -1 and low[u] >= disc[v]:
                is_cut[v] = True
        if parent == -1 and children > 1:
            is_cut[v] = True

    for v in range(n):
        if not visited[v]:
            dfs(v, -1)
    return [v for v in range(n) if is_cut[v]]
def difference_constraints(n, constraints):
    """Solve x_u - x_v <= c for every (u, v, c) over variables x_0..x_{n-1}.

    Each constraint becomes an edge v -> u of weight c, and a virtual source
    n reaches every variable with weight 0.  Bellman-Ford from the source
    gives a feasible assignment (all values <= 0), or None if a negative
    cycle makes the system infeasible.
    """
    source = n
    edges = [(v, u, c) for u, v, c in constraints]
    edges += [(source, i, 0) for i in range(n)]
    dist = [float('inf')] * (n + 1)
    dist[source] = 0
    for _ in range(n):  # |V| - 1 = n rounds of relaxation
        for frm, to, w in edges:
            if dist[frm] + w < dist[to]:
                dist[to] = dist[frm] + w
    if any(dist[frm] + w < dist[to] for frm, to, w in edges):
        return None  # still relaxable: negative cycle
    return dist[:n]
def euler_phi(n):
    """Euler's totient phi(n) via trial-division factorisation."""
    remaining = n
    phi = n
    factor = 2
    while factor * factor <= remaining:
        if remaining % factor == 0:
            phi = phi // factor * (factor - 1)
            while remaining % factor == 0:
                remaining //= factor
        factor += 1
    if remaining > 1:  # one prime factor larger than sqrt(n) is left
        phi = phi // remaining * (remaining - 1)
    return phi
def euler_phi_sieve(n):
    """phi[0..n] via an Eratosthenes-style sieve (phi[0] = 0, phi[1] = 1)."""
    phi = list(range(n + 1))
    for p in range(2, n + 1):
        if phi[p] != p:
            continue  # already reduced by a smaller prime, so p is composite
        for multiple in range(p, n + 1, p):
            phi[multiple] -= phi[multiple] // p
    return phi
def extended_gcd(a, b):
    """Extended Euclid: return (g, x, y) with a*x + b*y == g, |g| == gcd(a, b).

    Iterative form of the same remainder sequence as the classic recursion.
    """
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    return old_r, old_x, old_y
def crt(remainders, moduli):
    """Generalised Chinese remainder theorem: x ≡ remainders[i] (mod moduli[i]).

    Moduli need not be coprime.  Congruences are merged one at a time; the
    merged solution lies in [0, lcm).  Returns None if the system is
    inconsistent.  Uses extended_gcd from this file.
    """
    x, m = remainders[0], moduli[0]
    for i in range(1, len(remainders)):
        a, m2 = remainders[i], moduli[i]
        g, p, _ = extended_gcd(m, m2)
        if (a - x) % g != 0:
            return None
        lcm = m // g * m2
        x = (x + m * ((a - x) // g * p % (m2 // g))) % lcm
        m = lcm
    return x if x >= 0 else x + m
def discrete_log(a, b, m):
    """Baby-step giant-step: smallest x >= 0 with a**x ≡ b (mod m), or None.

    Guaranteed to find the answer when gcd(a, m) == 1.  Every candidate is
    verified with pow, so a non-None result is always a true solution.

    The previous version stored a**j in the baby table and multiplied b by
    a**n in the giant steps, but returned j + i*n.  That pairing actually
    solves a**(j - i*n) ≡ b, so it usually failed; e.g. it returned None for
    2**x ≡ 3 (mod 11), whose answer is 8.
    """
    b %= m
    if b == 1 % m:
        return 0
    n = int(m ** 0.5) + 1
    # Baby steps: b * a^j -> j for j in [0, n).  A later (larger) j
    # overwrites an earlier one so that x = i*n - j below is minimal.
    table = {}
    cur = b
    for j in range(n):
        table[cur] = j
        cur = cur * a % m
    # Giant steps: a^(i*n) for i = 1..n.  a^(i*n) == b * a^j  =>  x = i*n - j.
    giant = pow(a, n, m)
    cur = 1
    for i in range(1, n + 1):
        cur = cur * giant % m
        j = table.get(cur)
        if j is not None:
            x = i * n - j
            if pow(a, x, m) == b:
                return x
    return None
# Demo: solve 2**x ≡ 3 (mod 7).  Powers of 2 modulo 7 are only {1, 2, 4},
# so no solution exists and this prints None.
print(f"离散对数 log_2(3) mod 7 = {discrete_log(2, 3, 7)}")
def primitive_root(p):
    """Smallest primitive root modulo a prime p.

    g is a primitive root iff g^((p-1)/q) != 1 (mod p) for every prime q
    dividing p - 1.

    Returns 1 for p == 2.  The previous loop ran up to g == p and returned 2
    there, which is 0 mod 2.  Returns None if no g in [2, p) passes; results
    are meaningless when p is not prime.
    """
    if p == 2:
        return 1
    phi = p - 1
    # Distinct prime factors of p - 1 by trial division.
    factors = set()
    n = phi
    i = 2
    while i * i <= n:
        if n % i == 0:
            factors.add(i)
            while n % i == 0:
                n //= i
        i += 1
    if n > 1:
        factors.add(n)
    for g in range(2, p):  # g == p would be 0 (mod p)
        if all(pow(g, phi // f, p) != 1 for f in factors):
            return g
    return None
def ntt(a, invert, root, mod):
    """In-place number-theoretic transform modulo a prime mod.

    Args:
        a: list of residues; its length is assumed to be a power of two that
            divides mod - 1.
        invert: False for the forward transform, True for the inverse
            (inverse roots, then multiplication by len(a)^-1).
        root: a primitive root of mod.
        mod: NTT-friendly prime, e.g. 998244353.
    """
    size = len(a)

    # Permute entries into bit-reversed index order.
    rev = 0
    for idx in range(1, size):
        bit = size >> 1
        while rev & bit:
            rev ^= bit
            bit >>= 1
        rev ^= bit
        if idx < rev:
            a[idx], a[rev] = a[rev], a[idx]

    span = 2
    while span <= size:
        half = span // 2
        step_root = pow(root, (mod - 1) // span, mod)  # primitive span-th root of unity
        if invert:
            step_root = pow(step_root, mod - 2, mod)
        for base in range(0, size, span):
            w = 1
            for lo in range(base, base + half):
                hi = lo + half
                even, odd = a[lo], a[hi] * w % mod
                a[lo] = (even + odd) % mod
                a[hi] = (even - odd + mod) % mod
                w = w * step_root % mod
        span *= 2

    if invert:
        size_inv = pow(size, mod - 2, mod)
        for idx in range(size):
            a[idx] = a[idx] * size_inv % mod
def polynomial_multiply_ntt(a, b):
    """Multiply polynomials exactly modulo 998244353 using ntt() from this file.

    a and b are coefficient lists, lowest degree first.  The result has
    len(a) + len(b) - 1 coefficients, each reduced modulo 998244353.  Since
    998244353 = 119 * 2**23 + 1, the padded transform length may be at most
    2**23.

    Returns [] if either input is empty.  Previously, e.g., [] * [1, 2]
    produced [0].
    """
    if not a or not b:
        return []
    MOD = 998244353
    ROOT = 3  # primitive root of 998244353
    result_size = len(a) + len(b) - 1
    n = 1
    while n < result_size:
        n *= 2
    fa = a + [0] * (n - len(a))
    fb = b + [0] * (n - len(b))
    ntt(fa, False, ROOT, MOD)
    ntt(fb, False, ROOT, MOD)
    for i in range(n):
        fa[i] = fa[i] * fb[i] % MOD  # pointwise product = convolution
    ntt(fa, True, ROOT, MOD)
    return fa[:result_size]
# Demo: (1 + 2x + 3x^2) * (4 + 5x + 6x^2) via NTT; expected [4, 13, 28, 27, 18].
a = [1, 2, 3]
b = [4, 5, 6]
print(f"多项式乘积: {polynomial_multiply_ntt(a, b)}")
def card_puzzle():
    """Largest N such that 1..N can all be spelled with 2021 cards of each digit.

    Each number uses one card per digit it contains.  The search stops at
    the first number that cannot be formed with the remaining cards.
    """
    remaining = [2021] * 10
    num = 1
    while True:
        need = [0] * 10
        for ch in str(num):
            need[int(ch)] += 1
        if any(need[d] > remaining[d] for d in range(10)):
            return num - 1
        for d in range(10):
            remaining[d] -= need[d]
        num += 1
def count_lines(points):
    """Number of distinct straight lines through at least two of the points.

    Vertical and horizontal lines are keyed by their constant coordinate.
    Other lines are keyed by the integer triple (a, b, c) of
    a*x + b*y + c = 0, divided by its gcd and with a made positive.  Uses
    gcd (the file's own helper; math.gcd behaves the same on these inputs).

    Pairs of identical points are skipped, since they determine no line.
    The previous version recorded them as the vertical line x = x1.
    """
    lines = set()
    n = len(points)
    for i in range(n):
        for j in range(i + 1, n):
            x1, y1 = points[i]
            x2, y2 = points[j]
            if x1 == x2 and y1 == y2:
                continue
            if x1 == x2:
                lines.add(('v', x1))
            elif y1 == y2:
                lines.add(('h', y1))
            else:
                a = y2 - y1
                b = x1 - x2
                c = x2 * y1 - x1 * y2
                g = gcd(gcd(abs(a), abs(b)), abs(c))
                a, b, c = a // g, b // g, c // g
                if a < 0 or (a == 0 and b < 0):
                    a, b, c = -a, -b, -c
                lines.add((a, b, c))
    return len(lines)
题目:小蓝有一个超大的仓库,可以摆放很多货物。现在,小蓝有 n 箱货物要摆放在仓库,每箱货物都是规则的正方体。小蓝规定了“长”、“宽”、“高”三个方向,每箱货物的边都必须严格平行于长、宽、高。小蓝希望所有的货物最终摆成一个大的长方体,即在长、宽、高的方向上分别堆 a、b、c 箱货物,满足 a × b × c = n。请问有多少种堆放货物的方案?(a、b、c 的顺序不同视为不同方案,例如 n = 4 时有 1×1×4、1×4×1、4×1×1、1×2×2、2×1×2、2×2×1 共 6 种。)
def count_arrangements(n):
    """Number of ordered triples (a, b, c) of positive integers with a*b*c == n.

    For n = prod p_i^e_i, each exponent e_i is split among the three factors
    in C(e_i + 2, 2) ways, independently per prime, so the answer is the
    product of those terms.  This needs a single O(sqrt(n)) factorisation.
    The previous version counted the divisors of n // a separately for every
    divisor a.  Returns 0 for n < 1, as before.
    """
    if n < 1:
        return 0
    count = 1
    remaining = n
    p = 2
    while p * p <= remaining:
        if remaining % p == 0:
            e = 0
            while remaining % p == 0:
                remaining //= p
                e += 1
            count *= (e + 2) * (e + 1) // 2
        p += 1
    if remaining > 1:
        count *= 3  # leftover prime with exponent 1: C(3, 2) == 3
    return count
class FenwickTree:
    """Binary indexed tree over positions 1..n: point add and prefix-sum query."""

    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 1)

    def update(self, i, delta):
        """Add delta at position i (1-based)."""
        tree, size = self.tree, self.n
        while i <= size:
            tree[i] += delta
            i += i & -i

    def query(self, i):
        """Sum of positions 1..i (0 for i <= 0)."""
        total = 0
        tree = self.tree
        while i > 0:
            total += tree[i]
            i &= i - 1  # clears the lowest set bit, same as i -= i & -i
        return total
class TwoDTree:
    """2-D Fenwick tree: a Fenwick tree whose nodes are 1-D FenwickTree rows.

    Coordinates are 1-based, x in 1..n and y in 1..m.  Uses FenwickTree
    from this file.
    """

    def __init__(self, n, m):
        self.n = n
        self.m = m
        self.trees = [FenwickTree(m) for _ in range(n + 1)]

    def update(self, x, y, delta):
        """Add delta at cell (x, y)."""
        while x <= self.n:
            self.trees[x].update(y, delta)
            x += x & -x

    def query(self, x, y):
        """Sum over the rectangle (1, 1)..(x, y)."""
        total = 0
        while x > 0:
            total += self.trees[x].query(y)
            x -= x & -x
        return total

    def range_query(self, x1, y1, x2, y2):
        """Sum over the rectangle (x1, y1)..(x2, y2) by inclusion-exclusion."""
        whole = self.query(x2, y2)
        left = self.query(x1 - 1, y2)
        below = self.query(x2, y1 - 1)
        corner = self.query(x1 - 1, y1 - 1)
        return whole - left - below + corner
def matrix_mult(A, B, mod=10**9 + 7):
    """Matrix product A @ B with every entry reduced modulo mod."""
    rows, cols, inner = len(A), len(B[0]), len(B)
    C = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        row_a, row_c = A[i], C[i]
        for j in range(cols):
            acc = 0
            for p in range(inner):
                acc = (acc + row_a[p] * B[p][j]) % mod
            row_c[j] = acc
    return C
def matrix_pow(M, n, mod=10**9 + 7):
    """Return M**n modulo *mod* by binary exponentiation (identity for n == 0).

    Uses matrix_mult for every product; the input matrix is not modified.
    """
    size = len(M)
    acc = [[1 if r == c else 0 for c in range(size)] for r in range(size)]
    base = M
    exp = n
    while exp > 0:
        if exp & 1:
            acc = matrix_mult(acc, base, mod)
        base = matrix_mult(base, base, mod)
        exp >>= 1
    return acc
def fibonacci_matrix(n):
    """Return F(n) modulo 10**9 + 7 using powers of the Q-matrix [[1, 1], [1, 0]].

    For n <= 1 the input itself is returned (F(0) = 0, F(1) = 1).
    """
    if n > 1:
        q_power = matrix_pow([[1, 1], [1, 0]], n - 1)
        return q_power[0][0]
    return n
def linear_recurrence(coeffs, init, n, mod=10**9 + 7):
    """Return term a_n of a linear recurrence, modulo *mod* for n >= k.

    The recurrence is a_t = sum(coeffs[j] * a_{t-1-j} for j in range(k)),
    with k = len(coeffs), so coeffs[0] multiplies the most recent term.

    Args:
        coeffs: the k recurrence coefficients.
        init: the initial terms a_0..a_{k-1}.
        n: 0-based index of the wanted term.
        mod: modulus used in the matrix arithmetic and the result.

    For n < k the initial term is returned as given (not reduced).
    """
    order = len(coeffs)
    if n < order:
        return init[n]
    # Companion matrix: the first rows shift the window of terms, and the
    # last row produces the next term from the coefficients.
    companion = [[0] * order for _ in range(order)]
    for row in range(order - 1):
        companion[row][row + 1] = 1
    for col in range(order):
        companion[order - 1][col] = coeffs[order - 1 - col]
    power = matrix_pow(companion, n, mod)
    total = 0
    for col in range(order):
        total = (total + power[0][col] * init[col]) % mod
    return total
def next_permutation_of_bits(mask, n):
    """Return the next larger integer with the same number of set bits (Gosper's hack).

    Returns 0 when *mask* is 0 or when the successor does not fit in *n* bits.
    """
    if not mask:
        return 0
    lowest = mask & -mask      # lowest set bit
    ripple = mask + lowest     # carry the lowest run of ones one place up
    if ripple >= 1 << n:
        return 0
    changed = ripple ^ mask
    return ripple | ((changed >> 2) // lowest)
def traveling_salesman_dp(dist):
    """Return the minimum cost of a tour that starts and ends at city 0 and visits every city.

    Held-Karp bitmask DP: dp[mask][u] is the cheapest path that starts at
    city 0, visits exactly the cities in *mask*, and ends at *u*.
    O(2^n * n^2) time, O(2^n * n) memory.

    Args:
        dist: n x n matrix, dist[u][v] = cost of travelling from u to v.

    Returns:
        The optimal tour cost: 0 for n <= 1 (nothing to travel), and
        ``inf`` if no finite tour exists.
    """
    n = len(dist)
    if n <= 1:
        return 0
    INF = float('inf')
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0
    for mask in range(1, 1 << n):
        if not mask & 1:
            continue  # every valid path contains the start city 0
        row = dp[mask]
        for u in range(n):
            cost = row[u]
            # Skip unreachable states and end cities that are not in mask.
            if cost == INF or not mask & (1 << u):
                continue
            for v in range(n):
                if mask & (1 << v):
                    continue
                new_mask = mask | (1 << v)
                cand = cost + dist[u][v]
                if cand < dp[new_mask][v]:
                    dp[new_mask][v] = cand
    full_mask = (1 << n) - 1
    return min(dp[full_mask][u] + dist[u][0] for u in range(1, n))
def memo_template():
    """Skeleton for top-down DP with memoisation (a template, not runnable as-is).

    ``is_terminal``, ``base_case``, ``initial_value``, ``get_transitions``,
    ``combine`` and ``initial_state`` are placeholders to define for a
    concrete problem.  States must be hashable because results are cached
    with ``functools.lru_cache``, which is now imported here (it was used
    without an import).
    """
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def dfs(state):
        if is_terminal(state):
            return base_case(state)
        result = initial_value()
        for next_state in get_transitions(state):
            result = combine(result, dfs(next_state))
        return result

    return dfs(initial_state)
def solve_with_memo():
    """Skeleton for a position-indexed DP with two state components (template).

    ``n``, ``is_valid``, ``get_choices``, ``update_state1``, ``update_state2``,
    ``get_value``, ``initial_state1`` and ``initial_state2`` are placeholders
    for a concrete problem.  The result is the best total value, or ``-inf``
    when no valid completion exists.
    """
    from functools import lru_cache

    @lru_cache(maxsize=None)
    def dp(pos, state1, state2):
        if pos == n:
            # Only end states that pass the validity check count.
            return 0 if is_valid(state1, state2) else float('-inf')
        best = float('-inf')
        for choice in get_choices(pos, state1, state2):
            next1 = update_state1(state1, choice)
            next2 = update_state2(state2, choice)
            best = max(best, dp(pos + 1, next1, next2) + get_value(choice))
        return best

    return dp(0, initial_state1, initial_state2)
def compare(a, b):
    """Compare two big integers stored as little-endian digit lists.

    Returns a negative, zero or positive number when a < b, a == b or a > b
    (assuming no leading zeros).  For different lengths this is
    ``len(a) - len(b)``; otherwise it is the difference of the most
    significant differing digits.
    """
    length_gap = len(a) - len(b)
    if length_gap:
        return length_gap
    for digit_a, digit_b in zip(reversed(a), reversed(b)):
        if digit_a != digit_b:
            return digit_a - digit_b
    return 0
def subtract(a, b):
    """Return a - b for little-endian decimal digit lists (requires a >= b).

    Missing high digits of *b* count as 0.  The raw digit list is passed
    through ``remove_leading_zeros`` (a helper not shown in this section).
    """
    digits = []
    borrow = 0
    for idx, digit_a in enumerate(a):
        digit_b = b[idx] if idx < len(b) else 0
        value = digit_a - digit_b - borrow
        borrow = 1 if value < 0 else 0
        digits.append(value + 10 if borrow else value)
    return remove_leading_zeros(digits)
def quick_select_randomized(arr, k):
    """Return the k-th smallest element of *arr* (k is 1-based).

    Picks a random pivot, splits into smaller / equal / larger, and recurses
    only into the part that contains the answer.  Expected O(n) time.
    """
    if len(arr) == 1:
        return arr[0]
    pivot = random.choice(arr)
    smaller = [x for x in arr if x < pivot]
    equal_count = sum(1 for x in arr if x == pivot)
    if k <= len(smaller):
        return quick_select_randomized(smaller, k)
    if k <= len(smaller) + equal_count:
        return pivot
    larger = [x for x in arr if x > pivot]
    return quick_select_randomized(larger, k - len(smaller) - equal_count)
def randomized_partition(arr, low, high):
    """Partition arr[low..high] in place around a random pivot (Lomuto scheme).

    Returns the pivot's final index p, with arr[low..p-1] <= arr[p] < arr[p+1..high].
    """
    chosen = random.randint(low, high)
    arr[chosen], arr[high] = arr[high], arr[chosen]
    pivot = arr[high]
    boundary = low  # next slot for an element <= pivot
    for j in range(low, high):
        if arr[j] <= pivot:
            arr[boundary], arr[j] = arr[j], arr[boundary]
            boundary += 1
    arr[boundary], arr[high] = arr[high], arr[boundary]
    return boundary
def randomized_quicksort(arr, low, high):
    """Sort arr[low..high] in place using randomized_partition."""
    if low >= high:
        return
    split = randomized_partition(arr, low, high)
    randomized_quicksort(arr, low, split - 1)
    randomized_quicksort(arr, split + 1, high)
def shuffle_array(arr):
    """Shuffle *arr* in place with the Fisher-Yates algorithm and return the same list."""
    for last in reversed(range(1, len(arr))):
        pick = random.randint(0, last)
        arr[last], arr[pick] = arr[pick], arr[last]
    return arr
def reservoir_sample(stream, k):
    """Return up to k items drawn uniformly from an iterable (Algorithm R).

    If the stream has fewer than k items, all of them are returned in order.
    """
    reservoir = []
    for seen, item in enumerate(stream):
        if seen < k:
            reservoir.append(item)
            continue
        slot = random.randint(0, seen)
        if slot < k:
            reservoir[slot] = item
    return reservoir
class LCA:
    """Lowest-common-ancestor queries on a rooted tree using binary lifting.

    Args:
        adj: adjacency list of an undirected tree on nodes 0..n-1.
        root: root node.

    ``parent[k][v]`` is the 2**k-th ancestor of v (-1 if none), and
    ``depth[v]`` is v's distance from the root.  The tree is walked with an
    explicit stack, so deep (path-like) trees do not hit Python's recursion
    limit.  The number of lifting levels grows with n beyond the default 20.
    """

    def __init__(self, adj, root=0):
        self.n = len(adj)
        self.adj = adj
        self.root = root
        self.depth = [0] * self.n
        # 20 levels as before, more when n could exceed 2**20 nodes deep.
        self.log = max(20, self.n.bit_length())
        self.parent = [[-1] * self.n for _ in range(self.log)]
        if self.n:
            self._dfs(root, -1, 0)
        for k in range(1, self.log):
            prev, cur = self.parent[k - 1], self.parent[k]
            for v in range(self.n):
                if prev[v] != -1:
                    cur[v] = prev[prev[v]]

    def _dfs(self, v, p, d):
        """Record parent and depth for every node below v (iterative DFS)."""
        stack = [(v, p, d)]
        while stack:
            node, par, dep = stack.pop()
            self.parent[0][node] = par
            self.depth[node] = dep
            for child in self.adj[node]:
                if child != par:
                    stack.append((child, node, dep + 1))

    def get_lca(self, u, v):
        """Return the lowest common ancestor of u and v."""
        if self.depth[u] < self.depth[v]:
            u, v = v, u
        # Lift u to v's depth one set bit of the depth difference at a time.
        diff = self.depth[u] - self.depth[v]
        k = 0
        while diff:
            if diff & 1:
                u = self.parent[k][u]
            diff >>= 1
            k += 1
        if u == v:
            return u
        for k in range(self.log - 1, -1, -1):
            if self.parent[k][u] != self.parent[k][v]:
                u = self.parent[k][u]
                v = self.parent[k][v]
        return self.parent[0][u]

    def get_distance(self, u, v):
        """Return the number of edges on the path between u and v."""
        lca = self.get_lca(u, v)
        return self.depth[u] + self.depth[v] - 2 * self.depth[lca]
def sliding_window_max(arr, k):
    """Return the maximum of every length-k window of *arr* in O(n).

    The deque holds indices of the current window whose values are
    non-increasing from front to back, so the front is always the maximum.
    """
    window = deque()
    maxima = []
    for idx, value in enumerate(arr):
        while window and window[0] <= idx - k:
            window.popleft()
        while window and arr[window[-1]] < value:
            window.pop()
        window.append(idx)
        if idx + 1 >= k:
            maxima.append(arr[window[0]])
    return maxima
def sliding_window_min(arr, k):
    """Return the minimum of every length-k window of *arr* in O(n).

    The deque holds indices of the current window whose values are
    non-decreasing from front to back, so the front is always the minimum.
    """
    window = deque()
    minima = []
    for idx, value in enumerate(arr):
        while window and window[0] <= idx - k:
            window.popleft()
        while window and arr[window[-1]] > value:
            window.pop()
        window.append(idx)
        if idx + 1 >= k:
            minima.append(arr[window[0]])
    return minima
def next_greater_element(arr):
    """For each element, return the next element to its right that is strictly greater (-1 if none)."""
    answer = [-1] * len(arr)
    pending = []  # indices still waiting for a greater element; values non-increasing
    for idx, value in enumerate(arr):
        while pending and arr[pending[-1]] < value:
            answer[pending.pop()] = value
        pending.append(idx)
    return answer
def previous_smaller_element(arr):
    """For each element, return the nearest element to its left that is strictly smaller (-1 if none)."""
    answer = [-1] * len(arr)
    candidates = []  # indices whose values are strictly increasing
    for idx, value in enumerate(arr):
        while candidates and arr[candidates[-1]] >= value:
            candidates.pop()
        if candidates:
            answer[idx] = arr[candidates[-1]]
        candidates.append(idx)
    return answer
def largest_rectangle_in_histogram(heights):
    """Return the area of the largest rectangle that fits inside the histogram.

    Monotonic-stack algorithm in O(n).  A trailing zero-height sentinel
    flushes the stack at the end.  The sentinel is added to a local copy,
    so the caller's sequence is never modified and tuples are accepted.
    """
    bars = list(heights)
    bars.append(0)
    stack = []  # indices of bars with non-decreasing heights
    max_area = 0
    for i, h in enumerate(bars):
        while stack and bars[stack[-1]] > h:
            height = bars[stack.pop()]
            # The popped bar extends from just after the new stack top to i - 1.
            width = i - stack[-1] - 1 if stack else i
            max_area = max(max_area, height * width)
        stack.append(i)
    return max_area
def topological_sort_kahn(n, edges):
    """Topologically sort nodes 0..n-1 with Kahn's algorithm.

    Args:
        n: number of nodes.
        edges: iterable of (u, v) pairs meaning u must come before v.

    Returns:
        One valid order (ties broken first-in, first-out), or None if the
        graph has a cycle.
    """
    from collections import deque

    successors = [[] for _ in range(n)]
    indegree = [0] * n
    for src, dst in edges:
        successors[src].append(dst)
        indegree[dst] += 1
    ready = deque(node for node in range(n) if indegree[node] == 0)
    order = []
    while ready:
        node = ready.popleft()
        order.append(node)
        for nxt in successors[node]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)
    return order if len(order) == n else None
def course_schedule(num_courses, prerequisites):
    """Return True if all courses can be finished (the prerequisite graph is acyclic).

    Each pair (course, required) becomes the edge required -> course.
    """
    dependency_edges = [(required, course) for course, required in prerequisites]
    return topological_sort_kahn(num_courses, dependency_edges) is not None
def find_order(num_courses, prerequisites):
    """Return a valid order to take all courses, or None if the prerequisites are cyclic.

    Each pair (course, required) becomes the edge required -> course.
    """
    return topological_sort_kahn(
        num_courses, [(required, course) for course, required in prerequisites]
    )
def matrix_chain_multiplication(dims):
    """Return the minimum number of scalar multiplications needed to multiply a matrix chain.

    Matrix i has shape dims[i] x dims[i + 1], so there are len(dims) - 1
    matrices.  Interval DP: dp[i][j] is the cheapest cost of computing
    A_i ... A_j, trying every split point k.  O(n^3).

    Returns 0 when the chain has no matrices or just one.
    """
    n = len(dims) - 1
    if n < 1:
        return 0
    dp = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            dp[i][j] = min(
                dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                for k in range(i, j)
            )
    return dp[0][n - 1]
def optimal_bst(keys, freq):
    """Return the minimum weighted search cost of a binary search tree over sorted *keys*.

    freq[i] is the access frequency of keys[i].  The cost of a tree is
    sum(freq[i] * depth(keys[i])) with the root at depth 1.  Choosing root
    k for keys[i..j] adds the interval's total frequency, because every key
    in it moves one level deeper.  O(n^3).

    Returns 0 for an empty key list.
    """
    n = len(keys)
    if n == 0:
        return 0
    dp = [[0] * n for _ in range(n)]
    prefix_sum = [0] * (n + 1)
    for i in range(n):
        prefix_sum[i + 1] = prefix_sum[i] + freq[i]
    for length in range(1, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            weight = prefix_sum[j + 1] - prefix_sum[i]
            best = float('inf')
            for k in range(i, j + 1):
                left = dp[i][k - 1] if k > i else 0
                right = dp[k + 1][j] if k < j else 0
                best = min(best, left + right + weight)
            dp[i][j] = best
    return dp[0][n - 1]
def palindrome_partitioning(s):
    """Return the minimum number of cuts so that every piece of *s* is a palindrome.

    is_palindrome[i][j] is filled by increasing length; dp[i] is the minimum
    number of cuts for the prefix s[:i + 1].  O(n^2).

    Returns 0 for an empty string.
    """
    n = len(s)
    if n == 0:
        return 0
    is_palindrome = [[False] * n for _ in range(n)]
    for i in range(n):
        is_palindrome[i][i] = True
    for i in range(n - 1):
        is_palindrome[i][i + 1] = s[i] == s[i + 1]
    for length in range(3, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            is_palindrome[i][j] = s[i] == s[j] and is_palindrome[i + 1][j - 1]
    dp = [float('inf')] * n
    for i in range(n):
        if is_palindrome[0][i]:
            dp[i] = 0
        else:
            for j in range(i):
                if is_palindrome[j + 1][i]:
                    dp[i] = min(dp[i], dp[j] + 1)
    return dp[n - 1]
def stone_merge(stones):
    """Return the minimum total cost to merge a row of stone piles into one.

    Only adjacent piles can be merged, and a merge costs the combined
    weight.  Interval DP with prefix sums, O(n^3).

    Returns 0 for an empty row.
    """
    n = len(stones)
    if n == 0:
        return 0
    prefix = [0] * (n + 1)
    for i in range(n):
        prefix[i + 1] = prefix[i] + stones[i]
    dp = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            # The final merge of piles i..j always costs their total weight.
            span = prefix[j + 1] - prefix[i]
            dp[i][j] = min(dp[i][k] + dp[k + 1][j] for k in range(i, j)) + span
    return dp[0][n - 1]
def kruskal_with_restriction(n, edges, must_include, exclude):
    """Minimum spanning tree that must use some edges and may not use others.

    Args:
        n: number of vertices 0..n-1.
        edges: candidate (u, v, w) edges.
        must_include: edges added first (any that would close a cycle are skipped).
        exclude: edges that may not be used.

    Returns:
        Total weight of the resulting spanning tree, or -1 if the allowed
        edges cannot connect all n vertices.
    """
    root = list(range(n))
    rank = [0] * n

    def find(x):
        # Recursive path compression; union by rank keeps the depth O(log n).
        if root[x] != x:
            root[x] = find(root[x])
        return root[x]

    def union(x, y):
        rx, ry = find(x), find(y)
        if rx == ry:
            return False
        if rank[rx] < rank[ry]:
            rx, ry = ry, rx
        root[ry] = rx
        if rank[rx] == rank[ry]:
            rank[rx] += 1
        return True

    weight = 0
    used = 0
    for u, v, w in must_include:
        if union(u, v):
            weight += w
            used += 1
    candidates = sorted(
        (e for e in edges if e not in exclude and e not in must_include),
        key=lambda e: e[2],
    )
    for u, v, w in candidates:
        if union(u, v):
            weight += w
            used += 1
            if used == n - 1:
                break
    return weight if used == n - 1 else -1
def second_mst(n, edges):
    """Return (MST cost, cost of the cheapest spanning tree that costs strictly more).

    The second value comes from re-running Kruskal with each MST edge
    removed in turn.  Both values are ``inf`` when the graph is disconnected,
    and the second is ``inf`` when no strictly costlier spanning tree exists.
    The caller's *edges* list is not modified (a sorted copy is used).
    """

    def kruskal(n, edges, skip=None):
        parent = list(range(n))

        def find(x):
            # Iterative find with path halving (no recursion-depth limit).
            while parent[x] != x:
                parent[x] = parent[parent[x]]
                x = parent[x]
            return x

        total = 0
        mst_edges = []
        for i, (u, v, w) in enumerate(edges):
            if i == skip:
                continue
            ru, rv = find(u), find(v)
            if ru != rv:
                parent[ru] = rv
                total += w
                mst_edges.append(i)
        if len(mst_edges) == n - 1:
            return total, mst_edges
        return float('inf'), []

    edges = sorted(edges, key=lambda e: e[2])
    mst_cost, mst_edges = kruskal(n, edges)
    second_best = float('inf')
    for skip_idx in mst_edges:
        cost, _ = kruskal(n, edges, skip_idx)
        if cost > mst_cost:
            second_best = min(second_best, cost)
    return mst_cost, second_best
def minimum_spanning_arborescence(n, edges, root):
    """Return the cost of a minimum spanning arborescence (directed MST) rooted at *root*.

    Chu-Liu/Edmonds algorithm: every non-root node takes its cheapest
    incoming edge.  If these edges form no cycle they are optimal.
    Otherwise each cycle is contracted into a super-node, edges entering a
    node are re-weighted by subtracting that node's chosen in-edge weight,
    and the process repeats on the smaller graph.  Contraction is done in a
    loop rather than by recursion.

    Args:
        n: number of nodes 0..n-1.
        edges: iterable of (u, v, w), a directed edge u -> v with weight w.
        root: root node.

    Returns:
        The total weight, or ``float('inf')`` if some node is unreachable
        from the root.

    Self-loops are ignored.  Unreachable nodes are detected before walking
    parent pointers, which would otherwise follow parent -1 (read by Python
    as the last list element).
    """
    INF = float('inf')
    edges = list(edges)
    total = 0
    while True:
        in_edge = [INF] * n
        parent = [-1] * n
        for u, v, w in edges:
            if u != v and v != root and w < in_edge[v]:
                in_edge[v] = w
                parent[v] = u
        if any(in_edge[i] == INF for i in range(n) if i != root):
            return INF
        visited = [-1] * n    # visited[v] = start node of the walk that reached v
        cycle_id = [-1] * n
        cnt = 0
        for i in range(n):
            if i == root:
                continue
            total += in_edge[i]
            v = i
            while v != root and visited[v] == -1:
                visited[v] = i
                v = parent[v]
            # Returning to a node of this same walk closes a new cycle.
            if v != root and visited[v] == i:
                u = parent[v]
                while u != v:
                    cycle_id[u] = cnt
                    u = parent[u]
                cycle_id[v] = cnt
                cnt += 1
        if cnt == 0:
            return total
        for i in range(n):
            if cycle_id[i] == -1:
                cycle_id[i] = cnt
                cnt += 1
        new_edges = []
        for u, v, w in edges:
            new_u, new_v = cycle_id[u], cycle_id[v]
            # Edges inside a super-node vanish; edges into the root never matter.
            if new_u != new_v and v != root:
                new_edges.append((new_u, new_v, w - in_edge[v]))
        n, edges, root = cnt, new_edges, cycle_id[root]
def difference_array_2d(matrix, updates):
    """Apply rectangular range-add updates to *matrix* with a 2D difference array.

    Args:
        matrix: n x m grid of numbers (not modified).
        updates: iterable of (x1, y1, x2, y2, val); adds val to every cell
            with x1 <= row <= x2 and y1 <= col <= y2 (0-based, inclusive).

    Returns:
        A new n x m grid with all updates applied.
    """
    rows, cols = len(matrix), len(matrix[0])
    delta = [[0] * (cols + 2) for _ in range(rows + 2)]
    for x1, y1, x2, y2, val in updates:
        delta[x1][y1] += val
        delta[x1][y2 + 1] -= val
        delta[x2 + 1][y1] -= val
        delta[x2 + 1][y2 + 1] += val
    # An in-place 2D prefix sum turns delta into the total added to each cell.
    for i in range(rows):
        for j in range(cols):
            if i:
                delta[i][j] += delta[i - 1][j]
            if j:
                delta[i][j] += delta[i][j - 1]
            if i and j:
                delta[i][j] -= delta[i - 1][j - 1]
    return [[matrix[i][j] + delta[i][j] for j in range(cols)] for i in range(rows)]
def difference_on_tree(n, adj, queries, root=0):
    """Apply "add val to every node on the path u..v" queries with a tree difference array.

    For each query (u, v, val): diff[u] += val, diff[v] += val,
    diff[lca] -= val, and diff[parent(lca)] -= val.  A bottom-up subtree sum
    then turns diff into the total added to each node.

    Args:
        n: number of nodes 0..n-1.
        adj: adjacency list of an undirected tree.
        queries: iterable of (u, v, val).
        root: node used to orient the tree.

    Returns:
        A list whose entry x is the total value added to node x.

    The tree is traversed iteratively in BFS order, so deep trees do not hit
    the recursion limit.  Relies on ``find_lca(u, v, parent, depth)``.
    """
    parent = [-1] * n
    depth = [0] * n
    order = [root]
    head = 0
    while head < len(order):
        u = order[head]
        head += 1
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                depth[v] = depth[u] + 1
                order.append(v)
    diff = [0] * n
    for u, v, val in queries:
        diff[u] += val
        diff[v] += val
        lca = find_lca(u, v, parent, depth)
        diff[lca] -= val
        if parent[lca] != -1:
            diff[parent[lca]] -= val
    # Reverse BFS order visits children before parents: accumulate subtree sums.
    for u in reversed(order):
        if parent[u] != -1:
            diff[parent[u]] += diff[u]
    return diff
def find_lca(u, v, parent, depth):
    """Return the lowest common ancestor of u and v by climbing parent links.

    At each step the deeper node moves up (on equal depth, v moves), so a
    query takes O(depth) steps.
    """
    a, b = u, v
    while a != b:
        if depth[a] <= depth[b]:
            b = parent[b]
        else:
            a = parent[a]
    return a
def critical_path(n, adj, durations):
    """Critical-path method (CPM) analysis of a task DAG.

    Args:
        n: number of tasks 0..n-1.
        adj: adj[u] lists the tasks that can start only after u finishes.
        durations: durations[u] is the running time of task u.

    Returns:
        (total_time, critical_tasks, earliest, latest), where earliest[u] and
        latest[u] are the earliest and latest start times of u, and the
        critical tasks are those with zero slack (earliest == latest).
        An empty project (n == 0) returns (0, [], [], []).
    """
    if n == 0:
        return 0, [], [], []
    from collections import deque

    in_degree = [0] * n
    for u in range(n):
        for v in adj[u]:
            in_degree[v] += 1
    queue = deque([i for i in range(n) if in_degree[i] == 0])
    order = []
    earliest = [0] * n
    while queue:
        u = queue.popleft()
        order.append(u)
        for v in adj[u]:
            earliest[v] = max(earliest[v], earliest[u] + durations[u])
            in_degree[v] -= 1
            if in_degree[v] == 0:
                queue.append(v)
    total_time = max(earliest[i] + durations[i] for i in range(n))
    latest = [total_time - durations[i] for i in range(n)]
    # Walk backwards: a task must start early enough for all its successors.
    for u in reversed(order):
        for v in adj[u]:
            latest[u] = min(latest[u], latest[v] - durations[u])
    critical_tasks = [i for i in range(n) if earliest[i] == latest[i]]
    return total_time, critical_tasks, earliest, latest
def longest_path_dag(n, adj, weights):
    """Return the longest distance from any source to every node of a DAG.

    Sources (in-degree 0) start at distance 0.  Edge (u, v) has weight
    weights[(u, v)].  Nodes are relaxed in Kahn topological order.  A node
    that is never reached (e.g. one on a cycle) keeps ``-inf``.
    """
    from collections import deque

    indeg = [0] * n
    for u in range(n):
        for v in adj[u]:
            indeg[v] += 1
    best = [float('-inf')] * n
    sources = [node for node in range(n) if indeg[node] == 0]
    for node in sources:
        best[node] = 0
    queue = deque(sources)
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            candidate = best[u] + weights[(u, v)]
            if best[v] < candidate:
                best[v] = candidate
            indeg[v] -= 1
            if indeg[v] == 0:
                queue.append(v)
    return best
def boyer_moore(text, pattern):
    """Return every start index where *pattern* occurs in *text* (Boyer-Moore).

    Combines the bad-character rule (last occurrence of each character) with
    the strong good-suffix rule (Charras & Lecroq's ``suffixes``/``preBmGs``
    preprocessing).  On a mismatch at pattern position j the good-suffix
    shift is good_suffix[j], not good_suffix[j + 1]; using the shifted index
    can skip real matches (e.g. "ABAB" in "XAABAB").  After a full match
    good_suffix[0] is used, so overlapping matches are reported.
    An empty pattern yields [].
    """

    def build_bad_char_table(pattern):
        # Last index of every character in the pattern.
        return {ch: i for i, ch in enumerate(pattern)}

    def build_good_suffix_table(pattern):
        m = len(pattern)
        # suffix[i] = length of the longest common suffix of pattern[:i + 1] and pattern.
        suffix = [0] * m
        suffix[m - 1] = m
        for i in range(m - 2, -1, -1):
            j = i
            while j >= 0 and pattern[j] == pattern[m - 1 - i + j]:
                j -= 1
            suffix[i] = i - j
        table = [m] * m
        # Case 2: a prefix of the pattern matches a suffix of the matched part.
        j = 0
        for i in range(m - 1, -1, -1):
            if suffix[i] == i + 1:
                while j < m - 1 - i:
                    if table[j] == m:
                        table[j] = m - 1 - i
                    j += 1
        # Case 1: the matched suffix reoccurs earlier, preceded by a different character.
        for i in range(m - 1):
            table[m - 1 - suffix[i]] = m - 1 - i
        return table

    n, m = len(text), len(pattern)
    if m == 0:
        return []
    bad_char = build_bad_char_table(pattern)
    good_suffix = build_good_suffix_table(pattern)
    result = []
    i = 0
    while i <= n - m:
        j = m - 1
        while j >= 0 and pattern[j] == text[i + j]:
            j -= 1
        if j < 0:
            result.append(i)
            i += good_suffix[0]
        else:
            bc_shift = j - bad_char.get(text[i + j], -1)
            # good_suffix[j] >= 1, so the window always advances.
            i += max(bc_shift, good_suffix[j])
    return result
def sunday_search(text, pattern):
    """Return every start index of *pattern* in *text* using Sunday's algorithm.

    After each attempt the window jumps by the shift of the character just
    past it: m minus that character's last index in the pattern, or m + 1
    if the character does not occur in the pattern.  An empty pattern yields [].
    """
    n, m = len(text), len(pattern)
    if not m:
        return []
    jump = {}
    for idx, ch in enumerate(pattern):
        jump[ch] = m - idx
    hits = []
    start = 0
    while start <= n - m:
        if all(text[start + k] == pattern[k] for k in range(m)):
            hits.append(start)
        if start + m >= n:
            break
        start += jump.get(text[start + m], m + 1)
    return hits
def word_frequency(text):
    """Count case-insensitive word occurrences in *text*; a word is a run of word characters."""
    import re
    from collections import Counter

    tokens = re.findall(r'\b\w+\b', text.lower())
    counts = Counter()
    counts.update(tokens)
    return counts
def find_longest_common_prefix(strs):
    """Return the longest prefix shared by every string in *strs* ("" for no input)."""
    if not strs:
        return ""
    prefix = strs[0]
    for candidate in strs[1:]:
        size = 0
        limit = min(len(prefix), len(candidate))
        while size < limit and prefix[size] == candidate[size]:
            size += 1
        prefix = prefix[:size]
        if not prefix:
            return ""
    return prefix
def group_anagrams(strs):
    """Group words that are anagrams of each other; groups keep first-seen order."""
    buckets = {}
    for word in strs:
        buckets.setdefault(''.join(sorted(word)), []).append(word)
    return list(buckets.values())
def text_justification(words, max_width):
    """Fully justify *words* into lines of exactly *max_width* characters.

    Words are packed greedily.  Extra spaces go to the leftmost gaps.  A
    single-word line and the last line are left-justified and padded on
    the right.
    """

    def justify(line_words, letters):
        # letters = total characters of line_words, excluding spaces.
        gaps = len(line_words) - 1
        spaces = max_width - letters
        if gaps == 0:
            return line_words[0] + ' ' * spaces
        base, extra = divmod(spaces, gaps)
        pieces = []
        for pos, word in enumerate(line_words):
            pieces.append(word)
            if pos < gaps:
                pieces.append(' ' * (base + (1 if pos < extra else 0)))
        return ''.join(pieces)

    lines = []
    current = []
    letters = 0
    for word in words:
        # len(current) single spaces are needed between the words already placed and this one.
        if letters + len(current) + len(word) > max_width:
            lines.append(justify(current, letters))
            current = []
            letters = 0
        current.append(word)
        letters += len(word)
    if current:
        lines.append(' '.join(current) + ' ' * (max_width - letters - len(current) + 1))
    return lines
class Matrix:
    """Dense matrix stored as a list of row lists.

    Supports + (element-wise), * (matrix product), scalar multiplication,
    transpose, determinant and inverse.  The ValueError messages are the
    runtime strings used throughout this file.
    """

    def __init__(self, data):
        self.data = data
        self.rows = len(data)
        self.cols = len(data[0]) if data else 0

    def __add__(self, other):
        """Return the element-wise sum; raises ValueError if the shapes differ."""
        if self.rows != other.rows or self.cols != other.cols:
            raise ValueError("矩阵维度不匹配")
        result = [[self.data[i][j] + other.data[i][j] for j in range(self.cols)]
                  for i in range(self.rows)]
        return Matrix(result)

    def __mul__(self, other):
        """Return the matrix product; raises ValueError if the inner dimensions differ."""
        if self.cols != other.rows:
            raise ValueError("矩阵维度不匹配")
        result = [[0] * other.cols for _ in range(self.rows)]
        for i in range(self.rows):
            for j in range(other.cols):
                for k in range(self.cols):
                    result[i][j] += self.data[i][k] * other.data[k][j]
        return Matrix(result)

    def scalar_mul(self, c):
        """Return a new matrix with every entry multiplied by c."""
        result = [[self.data[i][j] * c for j in range(self.cols)] for i in range(self.rows)]
        return Matrix(result)

    def transpose(self):
        """Return the transposed matrix."""
        result = [[self.data[j][i] for j in range(self.rows)] for i in range(self.cols)]
        return Matrix(result)

    def determinant(self):
        """Return the determinant of a square matrix in O(n^3).

        All-integer matrices use the fraction-free Bareiss algorithm, so the
        result is an exact int.  Any other entries (float, Fraction, ...)
        use Gaussian elimination with partial pivoting.  Both replace
        cofactor expansion, which takes O(n!) time.

        Raises:
            ValueError: if the matrix is not square.
        """
        if self.rows != self.cols:
            raise ValueError("非方阵")
        n = self.rows
        if n == 0:
            return 0  # this class's convention for an empty matrix
        if n == 1:
            return self.data[0][0]
        a = [list(row) for row in self.data]
        if all(isinstance(x, int) for row in a for x in row):
            return self._bareiss_det(a)
        return self._gauss_det(a)

    @staticmethod
    def _bareiss_det(a):
        """Exact determinant of a square integer matrix (modifies *a*)."""
        n = len(a)
        sign = 1
        prev_pivot = 1
        for k in range(n - 1):
            if a[k][k] == 0:
                swap = next((r for r in range(k + 1, n) if a[r][k] != 0), None)
                if swap is None:
                    return 0
                a[k], a[swap] = a[swap], a[k]
                sign = -sign
            pivot = a[k][k]
            for i in range(k + 1, n):
                for j in range(k + 1, n):
                    # Bareiss guarantees this division is exact.
                    a[i][j] = (a[i][j] * pivot - a[i][k] * a[k][j]) // prev_pivot
            prev_pivot = pivot
        return sign * a[n - 1][n - 1]

    @staticmethod
    def _gauss_det(a):
        """Determinant by Gaussian elimination with partial pivoting (modifies *a*)."""
        n = len(a)
        det = 1
        for k in range(n):
            p = max(range(k, n), key=lambda r: abs(a[r][k]))
            if a[p][k] == 0:
                return a[p][k]  # a zero of the entries' own numeric type
            if p != k:
                a[k], a[p] = a[p], a[k]
                det = -det
            pivot = a[k][k]
            det *= pivot
            for i in range(k + 1, n):
                factor = a[i][k] / pivot
                if factor:
                    for j in range(k, n):
                        a[i][j] -= factor * a[k][j]
        return det

    def inverse(self):
        """Return the inverse matrix via Gauss-Jordan elimination with partial pivoting.

        The determinant is taken as the signed product of the pivots from
        this same elimination, so no separate determinant pass is needed.

        Raises:
            ValueError("非方阵"): the matrix is not square.
            ValueError("矩阵不可逆"): the matrix is empty or singular, i.e. a
                pivot column is entirely zero or |det| < 1e-10.
        """
        if self.rows != self.cols:
            raise ValueError("非方阵")
        n = self.rows
        if n == 0:
            raise ValueError("矩阵不可逆")
        augmented = [list(row) + [1 if i == j else 0 for j in range(n)]
                     for i, row in enumerate(self.data)]
        det = 1
        for i in range(n):
            max_row = max(range(i, n), key=lambda r: abs(augmented[r][i]))
            if augmented[max_row][i] == 0:
                raise ValueError("矩阵不可逆")
            if max_row != i:
                augmented[i], augmented[max_row] = augmented[max_row], augmented[i]
                det = -det
            pivot = augmented[i][i]
            det *= pivot
            for j in range(2 * n):
                augmented[i][j] /= pivot
            for k in range(n):
                if k != i:
                    factor = augmented[k][i]
                    for j in range(2 * n):
                        augmented[k][j] -= factor * augmented[i][j]
        if abs(det) < 1e-10:
            raise ValueError("矩阵不可逆")
        return Matrix([row[n:] for row in augmented])
def divisor_function(n):
    """Return (d(n), sigma(n)): the number and the sum of the positive divisors of n."""
    how_many = 0
    divisor_sum = 0
    d = 1
    while d * d <= n:
        if n % d == 0:
            partner = n // d
            if partner == d:
                how_many += 1
                divisor_sum += d
            else:
                how_many += 2
                divisor_sum += d + partner
        d += 1
    return how_many, divisor_sum
def sum_of_divisors_sieve(n):
    """Return sigma[0..n], where sigma[k] is the sum of divisors of k (sigma[0] = 0).

    Every d >= 2 is added to all of its multiples; the divisor 1 is the initial value.
    """
    sigma = [1] * (n + 1)
    sigma[0] = 0
    for d in range(2, n + 1):
        sigma[d::d] = [s + d for s in sigma[d::d]]
    return sigma
def number_of_divisors_sieve(n):
    """Return d[0..n], where d[k] is the number of divisors of k (d[0] = 0).

    Linear (Euler) sieve: each composite i * p is produced once, via its
    smallest prime factor p.
    """
    divisors = [1] * (n + 1)
    divisors[0] = 0
    composite = [False] * (n + 1)
    primes = []
    for i in range(2, n + 1):
        if not composite[i]:
            primes.append(i)
            divisors[i] = 2
        for p in primes:
            target = i * p
            if target > n:
                break
            composite[target] = True
            if i % p:
                # p is a new prime factor: the divisor count doubles.
                divisors[target] = divisors[i] * 2
                continue
            # p already divides i: its exponent e becomes e + 1.
            e, rest = 0, i
            while rest % p == 0:
                rest //= p
                e += 1
            divisors[target] = divisors[i] // (e + 1) * (e + 2)
            break
    return divisors
def mobius_function_sieve(n):
    """Return mu[0..n] of the Mobius function using a linear sieve.

    mu(1) = 1; mu(k) = 0 if k has a squared prime factor, otherwise
    (-1) ** (number of prime factors).  mu(0) is undefined and stored as 0,
    matching the zero at index 0 of the other divisor sieves.  With 1 there,
    sum(mu) would be off by one from the Mertens function.
    """
    mu = [1] * (n + 1)
    if mu:
        mu[0] = 0
    is_prime = [True] * (n + 1)
    primes = []
    for i in range(2, n + 1):
        if is_prime[i]:
            primes.append(i)
            mu[i] = -1
        for p in primes:
            if i * p > n:
                break
            is_prime[i * p] = False
            if i % p == 0:
                mu[i * p] = 0  # p**2 divides i * p
                break
            mu[i * p] = -mu[i]
    return mu
def find_integer(remainders):
    """Return the smallest positive x with x % m == a for every m -> a in *remainders*.

    Congruences are merged one at a time with the generalised Chinese
    remainder theorem, so the moduli need not be pairwise coprime.
    Returns None if the congruences are inconsistent; an empty mapping yields 1.
    """

    def extended_gcd(a, b):
        # Returns (g, x, y) with a*x + b*y == g == gcd(a, b).
        if b == 0:
            return a, 1, 0
        g, x, y = extended_gcd(b, a % b)
        return g, y, x - (a // b) * y

    def crt_pair(a1, m1, a2, m2):
        # Merge x = a1 (mod m1) with x = a2 (mod m2); (None, None) if impossible.
        g, p, _ = extended_gcd(m1, m2)
        gap = a2 - a1
        if gap % g != 0:
            return None, None
        lcm = m1 // g * m2
        t = gap // g * p % (m2 // g)
        return (a1 + m1 * t) % lcm, lcm

    residue, modulus = 0, 1
    for m, a in remainders.items():
        residue, modulus = crt_pair(residue, modulus, a, m)
        if residue is None:
            return None
    return residue if residue > 0 else residue + modulus
def max_submatrix(matrix):
    """Return the maximum sum of any non-empty contiguous rectangle of *matrix*.

    For each top row, the bottom row is extended downward while column sums
    accumulate, and Kadane's algorithm runs on every column-sum vector.
    O(n^2 * m).
    """
    rows, cols = len(matrix), len(matrix[0])
    best = float('-inf')
    for top in range(rows):
        col_sums = [0] * cols
        for bottom in range(top, rows):
            for j in range(cols):
                col_sums[j] += matrix[bottom][j]
            running = 0
            for s in col_sums:
                running = running + s if running > 0 else s
                best = max(best, running)
    return best
def unique_paths(m, n):
    """Return the number of right/down paths from the top-left to the bottom-right of an m x n grid."""
    grid = [[1] * n for _ in range(m)]
    for i in range(1, m):
        above, here = grid[i - 1], grid[i]
        for j in range(1, n):
            here[j] = above[j] + here[j - 1]
    return grid[-1][-1]
def unique_paths_with_obstacles(obstacle_grid):
    """Count right/down paths from the top-left to the bottom-right cell, avoiding cells equal to 1."""
    m, n = len(obstacle_grid), len(obstacle_grid[0])
    if obstacle_grid[0][0] == 1:
        return 0
    ways = [[0] * n for _ in range(m)]
    ways[0][0] = 1
    for i in range(m):
        for j in range(n):
            if obstacle_grid[i][j] == 1:
                ways[i][j] = 0
                continue
            from_top = ways[i - 1][j] if i > 0 else 0
            from_left = ways[i][j - 1] if j > 0 else 0
            ways[i][j] += from_top + from_left
    return ways[m - 1][n - 1]
# Demo: right/down paths through a 3 x 7 grid (C(8, 2) = 28).
print(f"不同路径数 (3x7): {unique_paths(3, 7)}")
2020 年真题：数字三角形
题目：给定一个数字三角形，从顶部出发，每次可以移动到下方相邻的数字，求从顶部到底部的最大路径和。
def maximum_path_sum(triangle):
    """Return the maximum top-to-bottom path sum in a number triangle (bottom-up DP).

    Row i has i + 1 entries.  From (i, j) a path may step to (i + 1, j) or
    (i + 1, j + 1).
    """
    n = len(triangle)
    best = triangle[-1][:]
    for i in range(n - 2, -1, -1):
        row = triangle[i]
        best = [row[j] + max(best[j], best[j + 1]) for j in range(i + 1)]
    return best[0]
def evaluate_expression_advanced(s):
    """Evaluate an arithmetic expression with + - * /, parentheses and unary minus.

    Uses the shunting-yard scheme with an operand stack and an operator stack.
    Numbers may be integers or decimals ("3", "2.5", ".5"), and all
    arithmetic is done in float.  Binary operators are left-associative,
    and '*' and '/' bind tighter than '+' and '-'.  A '+' or '-' at the
    start, after '(' or after another operator is unary, so "-3", "2*-3"
    and "-(1+2)" work.  Unary minus binds tighter than every binary
    operator.  Characters that are not numbers, operators or parentheses
    are ignored.
    """
    import re

    tokens = re.findall(r'\d+\.?\d*|\.\d+|[+\-*/()]', s)
    NEG = 'neg'  # operator-stack marker for unary minus

    def precedence(op):
        if op == NEG:
            return 3
        if op in ('+', '-'):
            return 1
        if op in ('*', '/'):
            return 2
        return 0  # '('

    def apply_op(a, b, op):
        if op == '+':
            return a + b
        if op == '-':
            return a - b
        if op == '*':
            return a * b
        return a / b

    values = []
    ops = []

    def reduce_top():
        # Pop one operator and apply it to the operand stack.
        op = ops.pop()
        if op == NEG:
            values.append(-values.pop())
        else:
            b, a = values.pop(), values.pop()
            values.append(apply_op(a, b, op))

    prev = None
    for token in tokens:
        if token == '(':
            ops.append(token)
        elif token == ')':
            while ops and ops[-1] != '(':
                reduce_top()
            ops.pop()
        elif token in ('+', '-', '*', '/'):
            if token in ('+', '-') and (prev is None or prev in ('(', '+', '-', '*', '/')):
                if token == '-':
                    ops.append(NEG)
                # A unary '+' changes nothing.
            else:
                while ops and precedence(ops[-1]) >= precedence(token):
                    reduce_top()
                ops.append(token)
        else:
            values.append(float(token))
        prev = token
    while ops:
        reduce_top()
    return values[0]
def max_profit_multiple(prices):
    """Return the max profit with unlimited buy/sell trades: the sum of all positive day-to-day rises."""
    return sum(max(0, today - yesterday) for yesterday, today in zip(prices, prices[1:]))
def max_profit_k_transactions(k, prices):
    """Return the max profit with at most k buy/sell transactions, holding at most one share.

    If k >= n // 2 the limit never binds, so max_profit_multiple answers
    directly.  Otherwise the DP keeps one row per transaction count and a
    running best "profit after buying" value, for O(k * n) time.
    """
    n = len(prices)
    if n < 2 or k < 1:
        return 0
    if k >= n // 2:
        return max_profit_multiple(prices)
    previous = [0] * n  # best profit using one transaction fewer, up to each day
    for _ in range(k):
        current = [0] * n
        best_buy = -prices[0]
        for day in range(1, n):
            current[day] = max(current[day - 1], prices[day] + best_buy)
            best_buy = max(best_buy, previous[day] - prices[day])
        previous = current
    return previous[n - 1]
def fast_input():
    """Read all of standard input at once and return an iterator over its whitespace-separated tokens."""
    return iter(stdin.read().split())
def solve():
    """Read n and then n integers from stdin, and print their sum."""
    tokens = fast_input()
    count = int(next(tokens))
    values = [int(next(tokens)) for _ in range(count)]
    print(sum(values))
def multi_test_cases():
    """Read t test cases (each n followed by n integers) from stdin and print each case's maximum."""
    tokens = fast_input()
    cases = int(next(tokens))
    for _ in range(cases):
        size = int(next(tokens))
        values = [int(next(tokens)) for _ in range(size)]
        print(max(values))
def read_matrix():
    """Read "n m" and then an n x m integer matrix (row-major) from stdin."""
    tokens = fast_input()
    n_rows = int(next(tokens))
    n_cols = int(next(tokens))
    return [[int(next(tokens)) for _ in range(n_cols)] for _ in range(n_rows)]
def output_format():
    """Print the same list of results in three common contest output layouts."""
    results = [1, 2, 3, 4, 5]
    as_text = [str(value) for value in results]
    print('\n'.join(as_text))  # one value per line
    print(' '.join(as_text))   # space-separated on one line
    for i, result in enumerate(results):
        print(f"Case {i + 1}: {result}")
# Plain-text cheat sheet of common algorithm templates (binary search, BFS,
# DFS, union-find, fast modular power, GCD), kept as a reference string.
ALGORITHM_TEMPLATES = """ 二分查找: left, right = 0, n - 1 while left <= right: mid = (left + right) // 2 if check(mid): right = mid - 1 else: left = mid + 1 BFS: queue = deque([start]) visited = set([start]) while queue: node = queue.popleft() for neighbor in get_neighbors(node): if neighbor not in visited: visited.add(neighbor) queue.append(neighbor) DFS: def dfs(node, visited): visited.add(node) for neighbor in graph[node]: if neighbor not in visited: dfs(neighbor, visited) 并查集: def find(x): if parent[x] != x: parent[x] = find(parent[x]) return parent[x] def union(x, y): px, py = find(x), find(y) if px != py: parent[px] = py 快速幂: def pow_mod(base, exp, mod): result = 1 while exp: if exp & 1: result = result * base % mod base = base * base % mod exp >>= 1 return result GCD: def gcd(a, b): while b: a, b = b, a % b return a """
def slow_fib(n):
    """Naive exponential-time Fibonacci, deliberately slow for comparison.

    Returns n itself for n <= 1.
    """
    return n if n <= 1 else slow_fib(n - 1) + slow_fib(n - 2)
def fast_fib(n):
    """Return the n-th Fibonacci number in O(n) time and O(1) extra space.

    For n <= 1 the input itself is returned (F(0) = 0, F(1) = 1), matching
    slow_fib.  The value is built bottom-up.  A memoised recursive version
    needs one stack frame per level and raises RecursionError once n nears
    Python's default recursion limit (about 1000).
    """
    if n <= 1:
        return n
    prev, cur = 0, 1
    for _ in range(n - 1):
        prev, cur = cur, prev + cur
    return cur
def matrix_fib(n):
    """Return F(n) modulo 10**9 + 7 via 2x2 matrix exponentiation.

    Self-contained: uses its own 2x2 multiply rather than the module-level
    matrix_mult / matrix_pow.  Returns n itself for n <= 1.
    """
    if n <= 1:
        return n
    MOD = 10**9 + 7

    def mul2(a, b):
        return [
            [(a[0][0] * b[0][0] + a[0][1] * b[1][0]) % MOD,
             (a[0][0] * b[0][1] + a[0][1] * b[1][1]) % MOD],
            [(a[1][0] * b[0][0] + a[1][1] * b[1][0]) % MOD,
             (a[1][0] * b[0][1] + a[1][1] * b[1][1]) % MOD],
        ]

    power = [[1, 0], [0, 1]]
    base = [[1, 1], [1, 0]]
    exp = n - 1
    while exp:
        if exp & 1:
            power = mul2(power, base)
        base = mul2(base, base)
        exp >>= 1
    return power[0][0]
def check_constraints(value, min_val, max_val):
    """Return *value* if min_val <= value <= max_val; otherwise raise ValueError."""
    if min_val <= value <= max_val:
        return value
    raise ValueError(f"Value {value} out of range [{min_val}, {max_val}]")
class HeavyLightDecomposition:
    """Heavy-light decomposition of a tree rooted at node 0.

    After construction:
      parent[v], depth[v], size[v] -- rooted-tree data;
      heavy[v] -- child with the largest subtree (first such child in
                  adjacency order, -1 for leaves);
      head[v]  -- topmost node of v's heavy chain;
      pos[v]   -- index of v in an order where every heavy chain is
                  contiguous (heavy child first, then light children).
    Both passes use explicit stacks, so deep trees do not hit Python's
    recursion limit.
    """

    def __init__(self, n, adj):
        self.n = n
        self.adj = adj
        self.parent = [-1] * n
        self.depth = [0] * n
        self.size = [0] * n
        self.heavy = [-1] * n
        self.head = [0] * n
        self.pos = [0] * n
        self.current_pos = 0
        self._dfs(0, -1)
        self._decompose(0, 0)

    def _dfs(self, u, p):
        """Fill parent, depth, size and heavy for the subtree of u (p is u's parent)."""
        order = [u]
        stack = [(u, p)]
        while stack:
            node, par = stack.pop()
            for v in self.adj[node]:
                if v != par:
                    self.parent[v] = node
                    self.depth[v] = self.depth[node] + 1
                    order.append(v)
                    stack.append((v, node))
        # Reverse discovery order handles every child before its parent.
        for node in reversed(order):
            par = p if node == u else self.parent[node]
            self.size[node] = 1
            max_size = 0
            for v in self.adj[node]:
                if v != par:
                    self.size[node] += self.size[v]
                    if self.size[v] > max_size:
                        max_size = self.size[v]
                        self.heavy[node] = v

    def _decompose(self, u, h):
        """Assign head and pos in DFS preorder, visiting the heavy child first."""
        stack = [(u, h)]
        while stack:
            node, chain_head = stack.pop()
            self.head[node] = chain_head
            self.pos[node] = self.current_pos
            self.current_pos += 1
            light = [v for v in self.adj[node]
                     if v != self.parent[node] and v != self.heavy[node]]
            # Light children are pushed in reverse so they pop in adjacency order.
            # The heavy child is pushed last so its whole chain is numbered first.
            for v in reversed(light):
                stack.append((v, v))
            if self.heavy[node] != -1:
                stack.append((self.heavy[node], chain_head))

    def lca(self, u, v):
        """Return the lowest common ancestor of u and v."""
        while self.head[u] != self.head[v]:
            if self.depth[self.head[u]] > self.depth[self.head[v]]:
                u = self.parent[self.head[u]]
            else:
                v = self.parent[self.head[v]]
        return u if self.depth[u] < self.depth[v] else v

    def path_query(self, u, v, tree):
        """Sum tree.query(l, r) over the position ranges covering the u-v path.

        *tree* must provide ``query(l, r)`` over inclusive pos indices.
        """
        result = 0
        while self.head[u] != self.head[v]:
            if self.depth[self.head[u]] > self.depth[self.head[v]]:
                result += tree.query(self.pos[self.head[u]], self.pos[u])
                u = self.parent[self.head[u]]
            else:
                result += tree.query(self.pos[self.head[v]], self.pos[v])
                v = self.parent[self.head[v]]
        if self.depth[u] > self.depth[v]:
            u, v = v, u
        result += tree.query(self.pos[u], self.pos[v])
        return result
def hungarian_algorithm(adj, n, m):
    """Maximum bipartite matching by repeated augmenting-path DFS (Kuhn's algorithm).

    Only the truthiness of adj[u][v] is used (no weights), so this finds a
    maximum-cardinality matching, not a weighted assignment.

    Args:
        adj: n x m matrix; a truthy adj[u][v] means left u may match right v.
        n: number of left vertices.
        m: number of right vertices.

    Returns:
        (matching size, match_r), where match_r[v] is v's left partner or -1.
    """
    match_r = [-1] * m

    def try_augment(u, visited):
        for v in range(m):
            if not adj[u][v] or visited[v]:
                continue
            visited[v] = True
            if match_r[v] == -1 or try_augment(match_r[v], visited):
                match_r[v] = u
                return True
        return False

    size = sum(1 for u in range(n) if try_augment(u, [False] * m))
    return size, match_r
def hopcroft_karp(adj, n, m):
    """Maximum bipartite matching in phases, Hopcroft-Karp style.

    Each phase runs a BFS that layers the left vertices by alternating-path
    distance from the free ones, then DFS augmentations along those layers.

    Args:
        adj: n x m matrix; a truthy adj[u][v] means left u may match right v.

    Returns:
        (matching size, pair_u, pair_v), with -1 marking unmatched vertices.
    """
    from collections import deque

    INF = float('inf')
    pair_u = [-1] * n
    pair_v = [-1] * m
    dist = [0] * n

    def bfs():
        queue = deque()
        for u in range(n):
            if pair_u[u] == -1:
                dist[u] = 0
                queue.append(u)
            else:
                dist[u] = INF
        free_reach = INF  # layer at which a free right vertex was first seen
        while queue:
            u = queue.popleft()
            if dist[u] >= free_reach:
                continue
            for v in range(m):
                if not adj[u][v]:
                    continue
                mate = pair_v[v]
                if mate == -1:
                    free_reach = dist[u] + 1
                elif dist[mate] == INF:
                    dist[mate] = dist[u] + 1
                    queue.append(mate)
        return free_reach != INF

    def dfs(u):
        for v in range(m):
            if not adj[u][v]:
                continue
            mate = pair_v[v]
            if mate == -1 or (dist[mate] == dist[u] + 1 and dfs(mate)):
                pair_u[u] = v
                pair_v[v] = u
                return True
        dist[u] = INF  # dead end for the rest of this phase
        return False

    matching = 0
    while bfs():
        for u in range(n):
            if pair_u[u] == -1 and dfs(u):
                matching += 1
    return matching, pair_u, pair_v
class SuffixAutomaton:
    """Suffix automaton of a string, built online one character at a time.

    Each entry of ``states`` is a dict with keys:
      'link' -- suffix link (-1 for the initial state 0),
      'len'  -- length of the longest string in the state's class,
      'next' -- outgoing transitions, character -> state index.
    """

    def __init__(self, s):
        self.states = [{'link': -1, 'len': 0, 'next': {}}]
        self.last = 0
        for ch in s:
            self._extend(ch)

    def _extend(self, c):
        states = self.states
        cur = len(states)
        states.append({'link': -1, 'len': states[self.last]['len'] + 1, 'next': {}})
        p = self.last
        # Walk suffix links from the previous last state, adding transitions on c.
        while p != -1 and c not in states[p]['next']:
            states[p]['next'][c] = cur
            p = states[p]['link']
        if p == -1:
            states[cur]['link'] = 0
        else:
            q = states[p]['next'][c]
            if states[p]['len'] + 1 == states[q]['len']:
                states[cur]['link'] = q
            else:
                # q also holds longer strings: split off a clone for the shorter ones.
                clone = len(states)
                states.append({
                    'link': states[q]['link'],
                    'len': states[p]['len'] + 1,
                    'next': dict(states[q]['next']),
                })
                while p != -1 and states[p]['next'].get(c) == q:
                    states[p]['next'][c] = clone
                    p = states[p]['link']
                states[q]['link'] = clone
                states[cur]['link'] = clone
        self.last = cur

    def count_distinct_substrings(self):
        """Return the number of distinct non-empty substrings of the built string."""
        states = self.states
        return sum(st['len'] - states[st['link']]['len'] for st in states[1:])

    def longest_common_substring(self, t):
        """Return the length of the longest common substring of the built string and *t*."""
        states = self.states
        state, length, best = 0, 0, 0
        for ch in t:
            while state != 0 and ch not in states[state]['next']:
                state = states[state]['link']
                length = states[state]['len']
            if ch in states[state]['next']:
                state = states[state]['next'][ch]
                length += 1
            best = max(best, length)
        return best
# Demo: "ababa" has 9 distinct non-empty substrings; its longest common
# substring with "bab" is "bab" itself (length 3).
sam = SuffixAutomaton("ababa")
print(f"不同子串数量: {sam.count_distinct_substrings()}")
print(f"与'bab'的最长公共子串: {sam.longest_common_substring('bab')}")
def build_suffix_array(s):
    """Return the suffix array of *s*: suffix start indices in lexicographic order.

    Prefix doubling: each round sorts suffixes by (rank of the first k
    characters, rank of the next k characters) and re-ranks them, stopping
    once all ranks are distinct.  O(n log^2 n).  An empty string yields [].
    """
    n = len(s)
    if n == 0:
        return []
    k = 1
    rank = [ord(c) for c in s]
    tmp = [0] * n
    sa = list(range(n))

    def second(i):
        # Rank of the k characters after position i; -1 past the end sorts first.
        return rank[i + k] if i + k < n else -1

    while True:
        sa.sort(key=lambda i: (rank[i], second(i)))
        tmp[sa[0]] = 0
        for idx in range(1, n):
            prev, curr = sa[idx - 1], sa[idx]
            tmp[curr] = tmp[prev]
            if rank[prev] != rank[curr] or second(prev) != second(curr):
                tmp[curr] += 1
        rank = tmp[:]
        if rank[sa[-1]] == n - 1:
            break
        k *= 2
    return sa
def build_lcp_array(s, sa):
    """Kasai's algorithm: lcp[i] is the longest common prefix of suffixes sa[i] and sa[i + 1].

    Runs in O(n) because the matched length drops by at most one between
    consecutive text positions.
    """
    n = len(s)
    inverse = [0] * n
    for position, start in enumerate(sa):
        inverse[start] = position
    lcp = [0] * (n - 1)
    matched = 0
    for start in range(n):
        position = inverse[start]
        if position == 0:
            matched = 0
            continue
        other = sa[position - 1]
        while (start + matched < n and other + matched < n
               and s[start + matched] == s[other + matched]):
            matched += 1
        lcp[position - 1] = matched
        matched = max(matched - 1, 0)
    return lcp
def count_distinct_substrings_sa(s):
    """Return the number of distinct non-empty substrings of *s*.

    All n(n + 1)/2 substrings minus the repeats, which the LCP array counts
    between neighbouring suffixes.
    """
    suffix_array = build_suffix_array(s)
    total_substrings = len(s) * (len(s) + 1) // 2
    return total_substrings - sum(build_lcp_array(s, suffix_array))
# Demo: suffix array and LCP array of "banana"
# (sa = [5, 3, 1, 0, 4, 2], lcp = [1, 3, 0, 0, 2]).
s = "banana"
sa = build_suffix_array(s)
lcp = build_lcp_array(s, sa)
def rotating_calipers(points):
    """Return (max distance, p, q) over pairs of convex-hull vertices of *points*.

    Relies on ``convex_hull`` (not shown in this section).  For each hull
    vertex i, a shared pointer j advances while the next vertex is strictly
    farther from hull[i].  Hulls of size 1 and 2 are handled directly; an
    empty hull gives (0, None, None).
    """
    hull = convex_hull(points)
    n = len(hull)

    def dist(p1, p2):
        return ((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2)**0.5

    if n == 1:
        return 0, hull[0], hull[0]
    if n == 2:
        return dist(hull[0], hull[1]), hull[0], hull[1]
    best, best_a, best_b = 0, None, None
    j = 1
    for i in range(n):
        while dist(hull[i], hull[(j + 1) % n]) > dist(hull[i], hull[j]):
            j = (j + 1) % n
        d = dist(hull[i], hull[j])
        if d > best:
            best, best_a, best_b = d, hull[i], hull[j]
    return best, best_a, best_b
def polygon_area(points):
    """Return the area of a simple polygon given its vertices in order (shoelace formula)."""
    count = len(points)
    twice_area = 0
    for i, cur in enumerate(points):
        nxt = points[(i + 1) % count]
        twice_area += cur[0] * nxt[1]
        twice_area -= nxt[0] * cur[1]
    return abs(twice_area) / 2
def circle_intersection(c1, r1, c2, r2):
    """Return the intersection points of two circles.

    Args:
        c1, c2: centres as (x, y).
        r1, r2: radii.

    Returns:
        [] if the circles do not meet, None if they coincide (infinitely
        many points), a one-point list when they touch, otherwise the two
        intersection points.

    h**2 = r1**2 - a**2 can come out slightly negative through rounding for
    (nearly) tangent circles.  In Python a negative float raised to 0.5 is a
    complex number, so the value is clamped to 0, and a zero half-chord is
    reported as a single tangent point.
    """
    dx = c2[0] - c1[0]
    dy = c2[1] - c1[1]
    d = (dx**2 + dy**2)**0.5
    if d > r1 + r2:
        return []
    if d < abs(r1 - r2):
        return []
    if d == 0 and r1 == r2:
        return None
    # a: distance from c1 to the chord's midpoint along the centre line.
    a = (r1**2 - r2**2 + d**2) / (2 * d)
    h = max(0.0, r1**2 - a**2) ** 0.5
    px = c1[0] + a * dx / d
    py = c1[1] + a * dy / d
    if h == 0 or d == r1 + r2 or d == abs(r1 - r2):
        return [(px, py)]
    rx = -h * dy / d
    ry = h * dx / d
    return [(px + rx, py + ry), (px - rx, py - ry)]
def calculate_work_time(records):
    """Total worked time per employee from clock-in / clock-out records.

    Args:
        records: iterable of records (emp_id, date, kind, time, ...).  kind
            'in' marks a clock-in and any other value a clock-out; times
            must support subtraction.

    Returns:
        dict mapping emp_id to the sum of (out_time - in_time) over the days
        that have both an 'in' and an 'out' record.  If a day has several
        records of one kind, the last one in input order wins.

    Records are grouped by (emp_id, date) on a sorted copy, so the caller's
    list is not reordered.  Missing times are checked with ``is None``, so a
    clock time of 0 still counts.
    """
    from itertools import groupby

    def day_key(rec):
        return rec[0], rec[1]

    ordered = sorted(records, key=day_key)  # stable: keeps input order within a day
    result = {}
    for (emp_id, _date), day_records in groupby(ordered, key=day_key):
        in_time = out_time = None
        for rec in day_records:
            if rec[2] == 'in':
                in_time = rec[3]
            else:
                out_time = rec[3]
        if in_time is not None and out_time is not None:
            result[emp_id] = result.get(emp_id, 0) + (out_time - in_time)
    return result
def min_partition_diff(arr):
    """Return the minimum |sum(A) - sum(B)| over splits of non-negative integers *arr* into two groups.

    Computes subset-sum reachability up to total // 2.  The best split uses
    the largest reachable sum j <= total // 2, giving a difference of
    total - 2 * j.
    """
    total = sum(arr)
    half = total // 2
    reachable = [False] * (half + 1)
    reachable[0] = True
    for num in arr:
        for s in range(half, num - 1, -1):
            if reachable[s - num]:
                reachable[s] = True
    for s in range(half, -1, -1):
        if reachable[s]:
            return total - 2 * s
    return total
def partition_with_trace(arr):
    """Split non-negative integers *arr* into two groups with minimum sum difference.

    Returns:
        (difference, part1, part2).  part1 is the chosen subset (listed from
        the last chosen index to the first); its sum is the largest
        achievable value <= sum(arr) // 2.  part2 holds the remaining
        elements in their original order.

    The traceback records chosen indices, not values, so duplicate values
    are split correctly and both parts come back intact.
    """
    total = sum(arr)
    n = len(arr)
    target = total // 2
    # dp[i][j]: some subset of the first i items sums to exactly j.
    dp = [[False] * (target + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = True
    for i in range(1, n + 1):
        w = arr[i - 1]
        for j in range(target + 1):
            dp[i][j] = dp[i - 1][j] or (j >= w and dp[i - 1][j - w])
    for j in range(target, -1, -1):
        if dp[n][j]:
            chosen = []
            remaining = j
            for i in range(n, 0, -1):
                w = arr[i - 1]
                # Take item i-1 whenever the rest can still be formed without it.
                if remaining >= w and dp[i - 1][remaining - w]:
                    chosen.append(i - 1)
                    remaining -= w
            chosen_set = set(chosen)
            part1 = [arr[idx] for idx in chosen]
            part2 = [x for idx, x in enumerate(arr) if idx not in chosen_set]
            return abs(j - (total - j)), part1, part2
    return total, arr, []
def sum_of_divisor_squares(n):
    """Return G(n) = sum over k = 1..n of sigma_2(k), the sum of squares of k's divisors.

    Brute-force reference: each d in 1..n divides exactly n // d of the
    numbers 1..n, so G(n) = sum(d**2 * (n // d) for d = 1..n).  O(n).
    Returns 0 for n <= 0.
    """
    return sum(d * d * (n // d) for d in range(1, n + 1))
def sum_of_divisor_squares_optimized(n):
    """Return G(n) = sum over d = 1..n of d**2 * (n // d) in O(sqrt(n)).

    n // d takes only O(sqrt n) distinct values.  For a block d in [i, j]
    sharing q = n // d, the contribution is q * (S(j) - S(i - 1)), where
    S(x) = x(x + 1)(2x + 1) / 6 is the sum of the first x squares.
    Returns 0 for n <= 0.
    """

    def square_sum(x):
        return x * (x + 1) * (2 * x + 1) // 6

    result = 0
    i = 1
    while i <= n:
        q = n // i
        j = n // q  # last d with the same quotient
        result += q * (square_sum(j) - square_sum(i - 1))
        i = j + 1
    return result
# Demo: the brute-force and block-decomposition versions compute the same
# sum of divisor squares for n = 10 and should print the same value.
n = 10
print(f"因数平方和 (暴力): {sum_of_divisor_squares(n)}")
print(f"因数平方和 (优化): {sum_of_divisor_squares_optimized(n)}")
def min_add_to_make_valid(s):
    """Return the minimum number of '(' or ')' insertions needed to balance *s*.

    Characters other than '(' and ')' are ignored.
    """
    open_unmatched = 0   # '(' still waiting for a ')'
    close_unmatched = 0  # ')' with no '(' before it
    for ch in s:
        if ch == '(':
            open_unmatched += 1
        elif ch == ')':
            if open_unmatched:
                open_unmatched -= 1
            else:
                close_unmatched += 1
    return close_unmatched + open_unmatched
def generate_valid_parentheses_with_fix(s):
    """Enumerate bracket-structured readings of *s* by memoised interval recursion.

    For an interval s[l..r]:
      * if s[l] is '(' or '[', pair it with every later matching closer
        s[mid] and combine every reading of the inside with every reading
        of the rest;
      * if that yields nothing, split the interval at every point and
        concatenate the readings of both halves;
      * if there is still nothing, the reading is just s[l].
    Returns the list for the whole string ([''] for an empty string).
    """
    n = len(s)
    memo = [[None] * n for _ in range(n)]

    def solve(l, r):
        if l > r:
            return ['']
        if memo[l][r] is not None:
            return memo[l][r]
        found = []
        if s[l] in '([':
            for mid in range(l, r + 1):
                if (s[l] == '(' and s[mid] == ')') or (s[l] == '[' and s[mid] == ']'):
                    inside = solve(l + 1, mid - 1)
                    rest = solve(mid + 1, r)
                    found.extend(s[l] + a + s[mid] + b for a in inside for b in rest)
        if not found:
            for mid in range(l, r):
                head = solve(l, mid)
                tail = solve(mid + 1, r)
                found.extend(a + b for a in head for b in tail)
        if not found:
            found = [s[l]]
        memo[l][r] = found
        return found

    return solve(0, n - 1)
# Module-level demo (runs on import): prints min_add_to_make_valid for "(()))".
# That string has one surplus ')', so the printed count is 1.
s = "(()))"
print(f"最少添加括号数: {min_add_to_make_valid(s)}")
def substring_score(s):
    """Brute-force total of distinct-character counts over all substrings of ``s``.

    For every start index, the substring is extended one character at a time.
    A set tracks the characters seen so far, and its size is added for each
    extension. O(n^2) set operations overall.

    Returns:
        The summed score (0 for an empty string).
    """
    total = 0
    for start in range(len(s)):
        distinct = set()
        for ch in s[start:]:
            distinct.add(ch)
            total += len(distinct)
    return total
def substring_score_optimized(s):
    """O(n) total of distinct-character counts over all substrings of ``s``.

    Each substring's distinct count is credited to the first occurrence of
    each character inside it. Position ``idx`` is such a first occurrence for
    every substring that starts after the previous occurrence of the same
    character and ends at or after ``idx``. That gives
    ``(idx - prev) * (len(s) - idx)`` substrings.

    Returns:
        The summed score; matches the brute-force version.
    """
    length = len(s)
    previous = {}  # character -> index of its most recent occurrence
    score = 0
    for idx, ch in enumerate(s):
        left_choices = idx - previous.get(ch, -1)
        right_choices = length - idx
        score += left_choices * right_choices
        previous[ch] = idx
    return score
# Module-level demo (runs on import): prints the brute-force ("暴力") and the
# O(n) contribution-based ("优化") substring scores for "ababc". Both are 28.
s = "ababc"
print(f"子串分值 (暴力): {substring_score(s)}")
print(f"子串分值 (优化): {substring_score_optimized(s)}")
def social_network_analysis(adj, n):
    """Compute basic BFS-based statistics for an unweighted graph.

    Runs one BFS from every node, so the cost is O(n * (n + edges)).

    Args:
        adj: Adjacency list; ``adj[u]`` iterates the neighbours of node ``u``.
            Nodes are the ints ``0..n-1``. Presumably undirected, with each
            edge listed from both ends. For a directed ``adj`` the
            ``clusters`` depend on node order. TODO confirm with callers.
        n: Number of nodes.

    Returns:
        dict with these keys:
          'avg_distances': for each node, the mean hop distance to the nodes
              it can reach (0 when it reaches no other node).
          'diameter': the largest finite hop distance between two nodes.
              Unreachable pairs are ignored; 0 for an empty graph.
          'clusters': node lists of the BFS flood-fill groups, each in BFS
              visit order.
          'degrees': ``len(adj[i])`` for each node.
          'most_central': the first node with the smallest average distance
              among nodes that reach at least one other node. If no node
              reaches another, it is 0; for an empty graph, -1.
          'most_connected': the first node of maximum degree (-1 if empty).
    """
    from collections import deque

    def bfs_distances(start):
        # Hop distance from `start` to every node; -1 marks unreachable nodes.
        dist = [-1] * n
        dist[start] = 0
        queue = deque([start])
        while queue:
            u = queue.popleft()
            for v in adj[u]:
                if dist[v] == -1:
                    dist[v] = dist[u] + 1
                    queue.append(v)
        return dist

    def find_clusters():
        # Flood-fill the graph with BFS; each group is listed in visit order.
        visited = [False] * n
        clusters = []
        for start in range(n):
            if visited[start]:
                continue
            cluster = []
            queue = deque([start])
            visited[start] = True
            while queue:
                u = queue.popleft()
                cluster.append(u)
                for v in adj[u]:
                    if not visited[v]:
                        visited[v] = True
                        queue.append(v)
            clusters.append(cluster)
        return clusters

    all_distances = [bfs_distances(i) for i in range(n)]

    avg_distances = []
    for dist in all_distances:
        reachable = [d for d in dist if d > 0]
        avg_distances.append(sum(reachable) / len(reachable) if reachable else 0)

    # -1 (unreachable) never beats the node's own 0, so only finite distances
    # count. default=0 keeps an empty graph from raising ValueError.
    diameter = max((max(dist) for dist in all_distances), default=0)

    clusters = find_clusters()
    degrees = [len(adj[i]) for i in range(n)]

    # An isolated node has average distance 0, but that does not make it
    # central. Only consider nodes that reach at least one other node.
    connected = [i for i in range(n) if any(d > 0 for d in all_distances[i])]
    if connected:
        most_central = min(connected, key=avg_distances.__getitem__)
    else:
        most_central = 0 if n else -1

    return {
        'avg_distances': avg_distances,
        'diameter': diameter,
        'clusters': clusters,
        'degrees': degrees,
        'most_central': most_central,
        'most_connected': degrees.index(max(degrees)) if degrees else -1,
    }
def handle_large_input():
    """Read a count followed by that many integers from standard input.

    Before reading, the recursion limit is raised to 10**6, a common setting
    for deeply recursive contest solutions. All of stdin is read in one call
    and split on whitespace. If there are too few tokens, ``next`` raises
    StopIteration.

    Returns:
        The list of integers that follow the leading count.
    """
    import sys

    sys.setrecursionlimit(10**6)
    tokens = iter(sys.stdin.read().split())
    count = int(next(tokens))
    return [int(next(tokens)) for _ in range(count)]
def precision_handling():
    """Add 0.1 and 0.2 exactly with ``Decimal`` at 50 significant digits.

    Binary floats give 0.1 + 0.2 == 0.30000000000000004. Decimal built from
    the string literals gives exactly ``Decimal('0.3')``.

    The 50-digit precision is applied in a local context, so the calling
    thread's global decimal context is left unchanged. Assigning to
    ``getcontext().prec`` would alter Decimal arithmetic everywhere else in
    the thread.

    Returns:
        ``Decimal('0.3')``.
    """
    from decimal import Decimal, localcontext

    with localcontext() as ctx:
        ctx.prec = 50
        a = Decimal('0.1')
        b = Decimal('0.2')
        c = a + b
    return c
def modular_arithmetic(mod=10**9 + 7):
    """Build ``add`` / ``mul`` / ``pow_mod`` helpers that work modulo ``mod``.

    Args:
        mod: Positive modulus. Defaults to the common contest prime 10**9 + 7.

    Returns:
        Tuple ``(add, mul, pow_mod)`` of closures over ``mod``. Results are
        always in ``range(mod)``.
    """
    def add(a, b):
        return (a + b) % mod

    def mul(a, b):
        return (a * b) % mod

    def pow_mod(base, exp):
        """Return ``base ** exp % mod`` by binary (square-and-multiply) exponentiation.

        Raises:
            ValueError: If ``exp`` is negative. A negative ``exp`` would never
                reach 0 under ``>>= 1`` (``-1 >> 1 == -1``), so the loop
                would not terminate.
        """
        if exp < 0:
            raise ValueError("pow_mod requires a non-negative exponent")
        result = 1 % mod  # 1 % mod keeps mod == 1 correct (everything is 0)
        while exp:
            if exp & 1:
                result = mul(result, base)
            base = mul(base, base)
            exp >>= 1
        return result

    return add, mul, pow_mod
def use_local_variables():
    """Return the square roots of 0..999 as a list of floats.

    Demonstrates binding ``math.sqrt`` to a local name once, so the
    per-element calls skip repeated global and attribute lookups.
    """
    import math

    sqrt = math.sqrt
    return list(map(sqrt, range(1000)))
def avoid_repeated_computation():
    """Return a memoized function computing ``x * x + 2 * x + 1``.

    Each returned function owns a private dict cache. Arguments must be
    hashable, and keys compare by ``==``, so equal values such as 3 and 3.0
    share one entry.
    """
    memo = {}

    def expensive_function(x):
        try:
            return memo[x]
        except KeyError:
            pass
        value = x * x + 2 * x + 1
        memo[x] = value
        return value

    return expensive_function
def list_comprehension_optimization():
    """Build three comprehension examples.

    Returns:
        Tuple ``(squares, even_squares, matrix)``:
          * ``squares``: squares of 0..99.
          * ``even_squares``: squares of the even numbers 0..98.
          * ``matrix``: a 10x10 multiplication table with
            ``matrix[i][j] == i * j``.
    """
    squares = [num * num for num in range(100)]
    even_squares = [num * num for num in range(0, 100, 2)]
    matrix = [[row * col for col in range(10)] for row in range(10)]
    return squares, even_squares, matrix
def generator_for_memory():
    """Return the first 20 Fibonacci numbers (starting 0, 1), produced by a generator.

    The inner generator yields values lazily. Only the final ``list`` call
    materializes them.
    """
    def fibonacci_generator(n):
        prev, curr = 0, 1
        produced = 0
        while produced < n:
            yield prev
            prev, curr = curr, prev + curr
            produced += 1

    return list(fibonacci_generator(20))
def simple_test(func, test_cases):
    """Run ``func`` over ``(input, expected)`` pairs and print a report.

    A tuple input is spread into positional arguments; any other input is
    passed as a single argument. Any ``Exception`` raised while calling
    ``func`` or comparing the result is reported, not propagated. A summary
    line is printed at the end.

    Returns:
        True when every case passed (including when ``test_cases`` is empty).
    """
    def _invoke(args):
        # Tuples become *args; everything else is one positional argument.
        return func(*args) if isinstance(args, tuple) else func(args)

    passed = 0
    for i, (input_data, expected) in enumerate(test_cases):
        try:
            result = _invoke(input_data)
            if result == expected:
                passed += 1
                print(f"测试 {i+1}: 通过")
                continue
            print(f"测试 {i+1}: 失败")
            print(f" 输入: {input_data}")
            print(f" 期望: {expected}")
            print(f" 实际: {result}")
        except Exception as e:
            print(f"测试 {i+1}: 错误 - {e}")
    print(f"\n通过率: {passed}/{len(test_cases)}")
    return passed == len(test_cases)
def performance_test(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)``, print the elapsed time, and return its result.

    Timing uses ``time.perf_counter()``, a monotonic high-resolution clock
    meant for measuring intervals. ``time.time()`` follows the wall clock,
    which can be adjusted mid-measurement and has coarser resolution on some
    platforms. Exceptions from ``func`` propagate unchanged.
    """
    import time

    start = time.perf_counter()
    result = func(*args, **kwargs)
    end = time.perf_counter()
    print(f"执行时间: {end - start:.4f} 秒")
    return result
def memory_usage():
    """Print the shallow size of a 1000-element list and return the size helper.

    The returned ``get_size`` wraps ``sys.getsizeof``. That reports only an
    object's own allocation, not the objects it references.
    """
    import sys

    def get_size(obj):
        # Shallow size in bytes, as reported by sys.getsizeof.
        return sys.getsizeof(obj)

    # Built with a comprehension on purpose: in CPython, list(range(...)) can
    # report a different getsizeof value because its capacity is
    # preallocated rather than grown by appends.
    arr = [k for k in range(1000)]
    print(f"列表内存: {get_size(arr)} 字节")
    return get_size