单调栈基础用法

适用场景

在一个一维数组中,对每个元素,寻找他的下一个或前一个最近的最大值或最小值的坐标或值。

若不考虑单调栈,采用暴力方法,对每个元素,分别向左向右遍历,找到第一个满足条件的元素就break,即可求解。暴力的方法的时间复杂度是O(n^2)

单调栈维护一个单调递增或递减的栈,以寻找下一个最小值的坐标为例,即在栈里维护大压小的顺序。一般来说,使用单调栈右三个阶段:

  1. 遍历阶段。遍历每一个元素,满足条件则入栈,不满足大压小的条件则弹出栈顶元素,每次弹出,则结算栈顶元素的答案。此时,使栈顶元素弹出的元素为,该栈顶元素的右答案,栈顶元素的下一个栈中元素为左答案。
  2. 清算阶段。如果栈中有剩余元素,依次弹出,每个弹出的元素的右答案为-1,如果栈顶元素不是最后一个元素,则左答案为栈中下一个元素,否则,为-1
  3. 修正阶段。如果数组中出现重复的元素,则需要修正答案。除最后一个元素外,倒序对每一个元素的右答案进行修正。如果nums[右答案]和元素的值相同,则修正该答案为ans[ans[i][1]][1],即右答案的右答案。

因为每一个元素进出栈一次,所以时间复杂度为O(n)

注意,单调栈维护的顺序和题目条件有关,例如,如果要找前一个或下一个最近的最大值坐标。则需要维护小压大的顺序。

算法模版

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#define MAXN 1000001

static int nums[MAXN];
static int stack[MAXN];
static int ans[MAXN][2];
static int n, r;

void f() {
r = 0;
int cur;
for (int i = 0; i < n; ++i) {
// 当栈不为空且入栈元素大于栈顶元素时
while (r > 0 && nums[stack[r-1]] >= nums[i]) {
// 弹出该元素,结算答案
cur = stack[--r];
// 压在下面的元素为左答案,使其弹出的元素为右答案
ans[cur][0] = r > 0 ? stack[r-1] : -1;
ans[cur][1] = i;
}
stack[r++] = i;
}
// 如果栈中还有元素
while (r > 0) {
cur = stack[--r];
// 此时右答案为-1 左答案为下一个元素
ans[cur][0] = r > 0 ? stack[r-1] : -1;
ans[cur][1] = -1;
}
// 修正答案 最后一个元素可以不考虑,因为都为-1
for (int i = n - 2; i >= 0; --i) {
if (ans[i][1] != -1 && nums[ans[i][1]] == nums[i]) {
ans[i][1] = ans[ans[i][1]][1];
}
}
// 输出
for (int i = 0; i < n; ++i) {
cout << ans[i][0] << ' ' << ans[i][1] << endl;
}
}

例题分析

496. 下一个更大元素 I

nums1 中数字 x 的 下一个更大元素 是指 xnums2 中对应位置 右侧 的 第一个 比 x 大的元素。

给你两个 没有重复元素 的数组 nums1nums2 ,下标从 0 开始计数,其中nums1nums2 的子集。

对于每个 0 <= i < nums1.length ,找出满足 nums1[i] == nums2[j] 的下标 j ,并且在 nums2 确定 nums2[j] 的 下一个更大元素 。如果不存在下一个更大元素,那么本次查询的答案是 -1

返回一个长度为 nums1.length 的数组 ans 作为答案,满足 ans[i] 是如上所述的 下一个更大元素 。

对于nums2,可以先求出每一个元素的下一个最大元素的位置,然后我们对nums2的值和下标做一个映射,为了方便查找nums1中值对应的位置。

求出每一个元素的下一个最大元素位置的过程,可以考虑使用单调栈。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
public:
int stk[1001];
vector<int> nextGreaterElement(vector<int>& nums1, vector<int>& nums2) {
unordered_map<int, int> m;
for (int i = 0; i < nums2.size(); ++i) {
m[nums2[i]] = i;
}
int n1 = nums1.size();
vector<int> ans(n1, -1);
int r = 0, n2 = nums2.size();
vector<int> next(n2, -1);
for (int i = 0; i < n2; ++i) {
while (r > 0 && nums2[i] > nums2[stk[r-1]]) {
int cur = stk[--r];
next[cur] = nums2[i];
}
stk[r++] = i;
}
for (int i = 0; i < n1; ++i) {
int pos = m[nums1[i]];
ans[i] = next[pos];
}
return ans;
}
};