HNOI2016 序列

使用莫队解决本题。考虑计算以每个数为终点的所有子串的最小值之和。$f(i)=f(pre(i))+(i-pre(i))*a_i$ 这里 $pre$ 代表前面第一个比 $a_i$ 小的数字的位置。这个可以用 ST 表加二分做到 $\mathcal O(n\times \log_2 n)$ 预处理。然后莫队转移区间时怎么计算右边新加入的数字造成的贡献?设当前q区间中最小数的位置为 $pos$,权值为 $val$。那么所有右端点为新加入的点,左端点在 $pos$ 左边(包括 $pos$)的子区间的z最小值都是 $val$,这部分区间的答案就是 $val\times (pos-l+1)$。对于左端点在 $pos$ 右边的子区间,$f(r)-f(pos)$ 即为答案。因为 $f(r)$ 中所有左端点超过 $pos$ 的也就是不需要的子区间的答案其实就是 $f(pos)$,画个图就很好理解。对于其他情况如删除也同理。注意在莫队转移的时候要先添加点再删除。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>

typedef long long LL;
#define int LL

const int MAXN = 1e5;

int n, m, blockSize, allL, allR, allAns;
int a[MAXN | 1], pre[MAXN | 1], nxt[MAXN | 1], rmqTable[18][MAXN | 1], rmqTablePos[18][MAXN | 1], lg[MAXN | 1];
int dp[2][MAXN | 1], ans[MAXN | 1];

struct Query {
  int l, r, belong, id;
  Query() : l(0), r(0), belong(0), id(0) {}
  friend bool operator< (const Query &lhs, const Query &rhs) {
    return (lhs.belong < rhs.belong) || (lhs.belong == rhs.belong && lhs.r < rhs.r);
  }
} que[MAXN | 1];

inline int read() {
  register int x = 0, v = 1;
  register char ch = getchar();
  while (!isdigit(ch)) {
    if (ch == '-') v = -1;
    ch = getchar();
  }
  while (isdigit(ch)) {
    x = x * 10 + ch - '0';
    ch = getchar();
  }
  return x * v;
}

inline int calcMin(int l, int r) {
  return std::min(rmqTable[lg[r - l + 1]][l], rmqTable[lg[r - l + 1]][r - (1 << lg[r - l + 1]) + 1]);
}

inline int calcMinPos(int l, int r) {
  return rmqTable[lg[r - l + 1]][l] < rmqTable[lg[r - l + 1]][r - (1 << lg[r - l + 1]) + 1] ? rmqTablePos[lg[r - l + 1]][l] : rmqTablePos[lg[r - l + 1]][r - (1 << lg[r - l + 1]) + 1];
}

void getPreAndNxt() {
  for (int i = 1; i <= n; ++i) {
    int l = 1, r = i, mid;
    pre[i] = 0;
    while (l <= r) {
      mid = (l + r) >> 1;
      if (calcMin(mid, i) < a[i]) {
        pre[i] = mid;
        l = mid + 1;
      } else r = mid - 1;
    }
  }
  for (int i = n; i >= 1; --i) {
    int l = i, r = n, mid;
    nxt[i] = n + 1;
    while (l <= r) {
      mid = (l + r) >> 1;
      if (calcMin(i, mid) < a[i]) {
        nxt[i] = mid;
        r = mid - 1;
      } else l = mid + 1;
    }
  }
  for (int i = n; i >= 1; --i) dp[0][i] = dp[0][nxt[i]] + (nxt[i] - i) * a[i];
  for (int i = 1; i <= n; ++i) dp[1][i] = dp[1][pre[i]] + (i - pre[i]) * a[i];
}

void add(int x, int opt) {
  if (opt == 0) {
    int minVal = calcMin(allL, allR), minPos = calcMinPos(allL, allR);
    allAns += dp[0][x] - dp[0][minPos] + minVal * (allR - minPos + 1);
  } else {
    int minVal = calcMin(allL, allR), minPos = calcMinPos(allL, allR);
    allAns += dp[1][x] - dp[1][minPos] + minVal * (minPos - allL + 1);
  }
}

void del(int x, int opt) {
  if (opt == 0) {
    int minVal = calcMin(allL, allR), minPos = calcMinPos(allL, allR);
    allAns -= dp[0][x] - dp[0][minPos] + minVal * (allR - minPos + 1);
  } else {
    int minVal = calcMin(allL, allR), minPos = calcMinPos(allL, allR);
    allAns -= dp[1][x] - dp[1][minPos] + minVal * (minPos - allL + 1);
  }
}

signed main() {
  n = read();
  m = read();
  blockSize = sqrt(n);
  memset(rmqTable, 0x3f, sizeof(rmqTable));
  for (int i = 1; i <= n; ++i) a[i] = read(), rmqTable[0][i] = a[i], rmqTablePos[0][i] = i;
  for (int i = 1; i < 18; ++i) {
    for (int j = 1; j + (1 << i) - 1 <= n; ++j) {
      if (rmqTable[i - 1][j] < rmqTable[i - 1][j + (1 << (i - 1))]) {
        rmqTable[i][j] = rmqTable[i - 1][j];
        rmqTablePos[i][j] = rmqTablePos[i - 1][j];
      } else {
        rmqTable[i][j] = rmqTable[i - 1][j + (1 << (i - 1))];
        rmqTablePos[i][j] = rmqTablePos[i - 1][j + (1 << (i - 1))];
      }
    }
  }
  lg[1] = 0;
  for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
  getPreAndNxt();
  for (int i = 1; i <= m; ++i) {
    que[i].l = read();
    que[i].r = read();
    que[i].id = i;
    que[i].belong = (que[i].l - 1) / blockSize + 1;
  }
  std::sort(que + 1, que + 1 + m);
  int &l = allL, &r = allR;
  l = 1;
  r = 1;
  allAns = a[1];
  for (int i = 1; i <= m; ++i) {
    while (l > que[i].l) add(--l, 0);
    while (r < que[i].r) add(++r, 1);

    while (l < que[i].l) del(l, 0), ++l;
    while (r > que[i].r) del(r, 1), --r;
    ans[que[i].id] = allAns;
  }
  for (int i = 1; i <= m; ++i) printf("%lld\n", ans[i]);
  return 0;
}
最后修改:2019 年 07 月 30 日 11 : 30 PM

发表评论