2019 Multi-University Training Contest 10 1011 Make Rounddog Happy

HDU 6701 Make Rounddog Happy
题意 : 给你$n$个数,和$k$,找到区间[l,r] $max(a_l,\dots,a_r)-(r-l+1)<=k$ 的数量(区间内不能出现有相同数字)。
题解: 相当于找区间长度大于区间最大值-k 的区间数量,第一个想到入手的肯定就是区间最大值,然后枚举左边或右边的边界,然后找另外一边的的上界。
假设寻找区间最大值,st表能够实现$O(1)$的查找,假设我们找到[l,r]区间的最大值下标为MID,这个MID很显然不会恰好在正中间,那么我们肯定是向里边界近的方向枚举,然后查询另一端的情况。(这个叫启发式分治 队友告诉我是$O(nlog(n))$) 。假设枚举右端下标为当前下标i,另一端的情况怎么获取呢,$O(n)$找肯定不行,另一端上界up肯定是$i-a[MID]-k$,下届呢?从i开始第一个出现相同数字的位置。这个可以$O(n)$ 预处理出来,每个位置上的数上一次出现的位置可以直接求出来,然后从某一个位置到第一个出现相同值的下标肯定是单调的,所以能$O(n)$预处理出第一个从i位置往前最长不出现相同值的位置,和向后最长不出现相同值的位置 这个地方卡了点常,我用2个ST表卡极限时间过了,实际上两个数组就能解决,找区间最大值实际上也能用单调栈解决可以不用ST表。画个图理解一下
在这里插入图片描述
如果是,枚举左边一样的处理。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#include "bits/stdc++.h"

using namespace std;
typedef long long LL;
typedef unsigned long long uLL;
typedef pair<int, int> P;
#define VNAME(value) (#value)
#define bug printf("*********\n");
#define debug(x) cout<<"["<<VNAME(x)<<" = "<<x<<"]"<<endl;
#define mid ((l + r) >> 1)
#define chl 2 * k + 1
#define chr 2 * k + 2
#define lson l, mid, chl
#define rson mid + 1, r, chr
#define eb(x) emplace_back(x)
#define pb(x) emplace_back(x)
#define mem(a, b) memset(a, b, sizeof(a));

const LL mod = (LL) 1e9 + 7;
const int maxn = (int) 3e5 + 5;
const LL INF = 0x7fffffff;
const LL inf = 0x3f3f3f3f;
const double eps = 1e-8;

#ifndef ONLINE_JUDGE
clock_t prostart = clock();
#endif

void f() {
#ifndef ONLINE_JUDGE
freopen("../data.in", "r", stdin);
#endif
}

//typedef __int128 LLL;

void read(int &w) {//读入
char c;
while (!isdigit(c = getchar()));
w = c & 15;
while (isdigit(c = getchar()))
w = w * 10 + (c & 15);
}

void output(LL x) {
if (x < 0)
putchar('-'), x = -x;
int ss[55], sp = 0;
do
ss[++sp] = x % 10;
while (x /= 10);
while (sp)
putchar(48 + ss[sp--]);
}

const int MXN = 3e5 + 10;
int a[maxn];
int pre[maxn], nxt[maxn];
int dp[MXN][20], pos[MXN][20];

int lg[maxn];

void init(int n) {
int LOG = lg[n] + 1;
for (int j = 1; j < LOG; ++j) {
if ((1 << (j - 1)) > n) break;
for (int i = 1; i + (1 << j) - 1 <= n; ++i) {
if (dp[i][j - 1] >= dp[i + (1 << (j - 1))][j - 1]) {
dp[i][j] = dp[i][j - 1];
pos[i][j] = pos[i][j - 1];
} else {
dp[i][j] = dp[i + (1 << (j - 1))][j - 1];
pos[i][j] = pos[i + (1 << (j - 1))][j - 1];
}
}
}
}

inline int query(int l, int r) {
int k = lg[r - l + 1];
if (dp[l][k] >= dp[r - (1 << k) + 1][k]) return pos[l][k];
return pos[r - (1 << k) + 1][k];
}

inline int query1(int l, int r) {
return pre[r];
}

inline int query2(int l, int r) {
return nxt[l];
}

int n, k;
LL ans = 0;

void solve(int l, int r) {
if (l > r) return;
if (l == r) {
if (a[l] - 1 <= k) ++ans;
return;
}
int MID = query(l, r);//最大值的位置
if (r - MID > MID - l) {
int up = min(query2(MID, r) - 1, r), low;
for (int i = MID; i >= l; --i) {//枚举左端点
up = min(nxt[i] - 1, up);
low = i + (a[MID] - k) - 1;
low = max(low, MID);
if (up < MID)break;
if (low > up)continue;
else {
ans += up - low + 1;
}
}
} else {
int up, low = max(query1(l, MID) + 1, l);
for (int i = MID; i <= r; ++i) {//枚举右端点
low = max(pre[i] + 1, low);
up = i - (a[MID] - k) + 1;
up = min(up, MID);
if (low > MID)break;
if (low > up)continue;
else {
ans += up - low + 1;
}
}
}
solve(l, MID - 1);
solve(MID + 1, r);
}

int p[maxn];

int main() {
f();
int T;
read(T);
for (int i = 1; i <= 3e5; i++) {
lg[i] = log2(i);
}
while (T--) {
read(n);
read(k);
for (int i = 1; i <= n; i++) {
read(a[i]);
dp[i][0] = a[i];
pos[i][0] = i;
}
for (int i = 1; i <= n; i++) {
p[i] = 0;
}
for (int i = 1; i <= n; i++) {
pre[i] = p[a[i]];
if (i != 1)pre[i] = max(pre[i], pre[i - 1]);
p[a[i]] = i;
}
for (int i = 1; i <= n; i++) {
p[i] = n + 1;
}
for (int i = n; i >= 1; i--) {
nxt[i] = p[a[i]];
if (i != n)nxt[i] = min(nxt[i], nxt[i + 1]);
p[a[i]] = i;
}
for (int i = 1; i <= n; ++i) {
// dp1[i][0] = pre[i];
// dp2[i][0] = nxt[i];
}
init(n);
ans = 0;
solve(1, n);
output(ans);
puts("");
}
return 0;
}