2547-拆分数组的最小代价

Raphael Liu Lv10

给你一个整数数组 nums 和一个整数 k

将数组拆分成一些非空子数组。拆分的 代价 是每个子数组中的 重要性 之和。

trimmed(subarray) 作为子数组的一个特征,其中所有仅出现一次的数字将会被移除。

  • 例如,trimmed([3,1,2,4,3,4]) = [3,4,3,4]

子数组的 重要性 定义为 k + trimmed(subarray).length

  • 例如,如果一个子数组是 [1,2,3,3,3,4,4]trimmed([1,2,3,3,3,4,4]) = [3,3,3,4,4] 。这个子数组的重要性就是 k + 5

找出并返回拆分 nums 的所有可行方案中的最小代价。

子数组 是数组的一个连续 非空 元素序列。

示例 1:

**输入:** nums = [1,2,1,2,1,3,3], k = 2
**输出:** 8
**解释:** 将 nums 拆分成两个子数组:[1,2], [1,2,1,3,3]
[1,2] 的重要性是 2 + (0) = 2 。
[1,2,1,3,3] 的重要性是 2 + (2 + 2) = 6 。
拆分的代价是 2 + 6 = 8 ,可以证明这是所有可行的拆分方案中的最小代价。

示例 2:

**输入:** nums = [1,2,1,2,1], k = 2
**输出:** 6
**解释:** 将 nums 拆分成两个子数组:[1,2], [1,2,1] 。
[1,2] 的重要性是 2 + (0) = 2 。
[1,2,1] 的重要性是 2 + (2) = 4 。
拆分的代价是 2 + 4 = 6 ,可以证明这是所有可行的拆分方案中的最小代价。

示例 3:

**输入:** nums = [1,2,1,2,1], k = 5
**输出:** 10
**解释:** 将 nums 拆分成一个子数组:[1,2,1,2,1].
[1,2,1,2,1] 的重要性是 5 + (3 + 2) = 10 。
拆分的代价是 10 ,可以证明这是所有可行的拆分方案中的最小代价。

提示:

  • 1 <= nums.length <= 1000
  • 0 <= nums[i] < nums.length
  • 1 <= k <= 109

方法一:O(n^2) 划分型动态规划

如何思考

划分出第一个子数组,问题变成一个规模更小的子问题。

由于「划分出长为 x 和 y 的子数组」和「划分出长为 y 和 x 的子数组」之后,剩余的子问题是相同的,因此这题适合用动态规划解决。

附:视频讲解

具体算法

定义 f[i+1] 表示划分 nums 的前 i 个数的最小代价,从 i 开始倒序枚举最后一个子数组的开始位置 j,同时用一个数组 state 维护每个元素的出现次数,用一个变量 unique 维护只出现一次的元素个数。

具体来说:

  • state}[x]=0 表示 x 出现 0 次;
  • state}[x]=1 表示 x 出现 1 次;
  • state}[x]=2 表示 x 出现超过 1 次。
  • 如果 x 首次遇到,那么 unique 加一,state}[x]=1;
  • 如果 x 第二次遇到,那么 unique 减一,state}[x]=2。

经测试,这种写法比直接计算特征要更快一些,尤其是 Python。

重要性为子数组的长度减去只出现一次的元素个数加 k,即

i-j+1 - \textit{unique}_{j,i} + k

这里 unique}_{j,i 表示枚举到 j 时的 unique 值。

加上前面子数组的最小代价,所有结果取最小值,得

\begin{aligned}
f[i+1] &= \min\limits_{j=0}^{i} f[j] + i-j+1 - \textit{unique}{j,i} + k \
&= i+1+k+ \min\limits
{j=0}^{i} f[j] -j - \textit{unique}_{j,i}
\end{aligned}

初始值 f[0] = 0,答案为 f[n]。

优化

注意到 f[j] 每次都要减去 j,而 f[i+1] 最后还要加上 i+1,如果定义 f’[i] = f[i]-i,则有

f’[i+1] = k+\min\limits_{j=0}^{i} f’[j] - \textit{unique}_{j,i}

由于 f’[n] = f[n]-n,所以最后答案为 f’[n]+n。

[sol1-Python3]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class Solution:
def minCost(self, nums: List[int], k: int) -> int:
n = len(nums)
f = [0] * (n + 1)
for i in range(n):
state, unique, mn = [0] * n, 0, inf
for j in range(i, -1, -1):
x = nums[j]
if state[x] == 0: # 首次出现
state[x] = 1
unique += 1
elif state[x] == 1: # 不再唯一
state[x] = 2
unique -= 1
mn = min(mn, f[j] - unique)
# if f[j]-unique < mn: mn = f[j]-unique # 手写 min 会快很多
f[i + 1] = k + mn
return f[n] + n
[sol1-Java]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Solution {
public int minCost(int[] nums, int k) {
int n = nums.length;
int[] f = new int[n + 1];
byte[] state = new byte[n];
for (int i = 0; i < n; ++i) {
Arrays.fill(state, (byte) 0);
int unique = 0, mn = Integer.MAX_VALUE;
for (int j = i; j >= 0; --j) {
int x = nums[j];
if (state[x] == 0) { // 首次出现
state[x] = 1;
++unique;
} else if (state[x] == 1) { // 不再唯一
state[x] = 2;
--unique;
}
mn = Math.min(mn, f[j] - unique);
}
f[i + 1] = k + mn;
}
return f[n] + n;
}
}
[sol1-C++]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class Solution {
public:
int minCost(vector<int> &nums, int k) {
int n = nums.size(), f[n + 1];
f[0] = 0;
int8_t state[n];
for (int i = 0; i < n; ++i) {
memset(state, 0, sizeof(state));
int unique = 0, mn = INT_MAX;
for (int j = i; j >= 0; --j) {
int x = nums[j];
if (state[x] == 0) state[x] = 1, ++unique; // 首次出现
else if (state[x] == 1) state[x] = 2, --unique; // 不再唯一
mn = min(mn, f[j] - unique);
}
f[i + 1] = k + mn;
}
return f[n] + n;
}
};
[sol1-Go]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
func minCost(nums []int, k int) int {
n := len(nums)
f := make([]int, n+1)
for i := 0; i < n; i++ {
state, unique, mn := make([]int8, n), 0, math.MaxInt
for j := i; j >= 0; j-- {
x := nums[j]
if state[x] == 0 { // 首次出现
state[x] = 1
unique++
} else if state[x] == 1 { // 不再唯一
state[x] = 2
unique--
}
mn = min(mn, f[j]-unique)
}
f[i+1] = mn + k
}
return f[n] + n
}

func min(a, b int) int { if a > b { return b }; return a }

复杂度分析

  • 时间复杂度:O(n^2),其中 n 为 nums 的长度。
  • 空间复杂度:O(n)。

方法二:线段树优化

记 x 上一次出现的位置是 last}[x],上上一次出现的位置是 last}_2[x]。

如果从左到右枚举 x=\textit{nums}[i],那么:

  • 区间 [\textit{last}[x]+1,i] 内的数的 unique 都加一;
  • 区间 [\textit{last}_2[x]+1,\textit{last}[x]] 内的数的 unique 都减一,相当于把之前的加一撤销掉(如果 last}_2[x] 不存在则不更新)。

此外,我们求的是一段区间内的 f[j]-\textit{unique}{j,i 的最小值,这可以用线段树优化(区间更新,区间查询)。注意 unique}{j,i 前面是负号,所以上面的区间更新值要取反。

代码实现时,由于枚举 nums}[i] 时,更新的是 f[i+1],我们可以在上一轮循环时把它记录下来,在下一轮循环时去把它加到线段树中。

[sol2-Python3]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Solution:
def minCost(self, nums: List[int], k: int) -> int:
# Lazy 线段树模板(区间加,查询区间最小)
n = len(nums)
mn = [0] * (4 * n)
todo = [0] * (4 * n)

def do(o: int, v: int) -> None:
mn[o] += v
todo[o] += v

def spread(o: int) -> None:
v = todo[o]
if v:
do(o * 2, v)
do(o * 2 + 1, v)
todo[o] = 0

# 区间 [L,R] 内的数都加上 v o,l,r=1,1,n
def update(o: int, l: int, r: int, L: int, R: int, v: int) -> None:
if L <= l and r <= R:
do(o, v)
return
spread(o)
m = (l + r) // 2
if m >= L: update(o * 2, l, m, L, R, v)
if m < R: update(o * 2 + 1, m + 1, r, L, R, v)
mn[o] = min(mn[o * 2], mn[o * 2 + 1])

# 查询区间 [L,R] 的最小值 o,l,r=1,1,n
def query(o: int, l: int, r: int, L: int, R: int) -> int:
if L <= l and r <= R:
return mn[o]
spread(o)
m = (l + r) // 2
if m >= R: return query(o * 2, l, m, L, R)
if m < L: return query(o * 2 + 1, m + 1, r, L, R)
return min(query(o * 2, l, m, L, R), query(o * 2 + 1, m + 1, r, L, R))

ans = 0
last = [0] * n
last2 = [0] * n
for i, x in enumerate(nums, 1):
update(1, 1, n, i, i, ans) # 相当于设置 f[i+1] 的值
update(1, 1, n, last[x] + 1, i, -1)
if last[x]: update(1, 1, n, last2[x] + 1, last[x], 1)
ans = k + query(1, 1, n, 1, i)
last2[x] = last[x]
last[x] = i
return ans + n
[sol2-Java]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
class Solution {
public int minCost(int[] nums, int k) {
int n = nums.length, ans = 0;
mn = new int[n * 4];
todo = new int[n * 4];
int[] last = new int[n], last2 = new int[n];
for (int i = 1; i <= n; ++i) {
int x = nums[i - 1];
update(1, 1, n, i, i, ans); // 相当于设置 f[i+1] 的值
update(1, 1, n, last[x] + 1, i, -1);
if (last[x] > 0) update(1, 1, n, last2[x] + 1, last[x], 1);
ans = k + query(1, 1, n, 1, i);
last2[x] = last[x];
last[x] = i;
}
return ans + n;
}

// Lazy 线段树模板(区间加,查询区间最小)
private int[] mn, todo;

private void do_(int o, int v) {
mn[o] += v;
todo[o] += v;
}

private void spread(int o) {
int v = todo[o];
if (v != 0) {
do_(o * 2, v);
do_(o * 2 + 1, v);
todo[o] = 0;
}
}

// 区间 [L,R] 内的数都加上 v o,l,r=1,1,n
private void update(int o, int l, int r, int L, int R, int v) {
if (L <= l && r <= R) {
do_(o, v);
return;
}
spread(o);
int m = (l + r) / 2;
if (m >= L) update(o * 2, l, m, L, R, v);
if (m < R) update(o * 2 + 1, m + 1, r, L, R, v);
mn[o] = Math.min(mn[o * 2], mn[o * 2 + 1]);
}

// 查询区间 [L,R] 的最小值 o,l,r=1,1,n
private int query(int o, int l, int r, int L, int R) {
if (L <= l && r <= R)
return mn[o];
spread(o);
int m = (l + r) / 2;
if (m >= R) return query(o * 2, l, m, L, R);
if (m < L) return query(o * 2 + 1, m + 1, r, L, R);
return Math.min(query(o * 2, l, m, L, R), query(o * 2 + 1, m + 1, r, L, R));
}
}
[sol2-C++]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
// Lazy 线段树模板(区间加,查询区间最小)
const int MAX_N = 1000;
int mn[MAX_N * 4], todo[MAX_N * 4], last[MAX_N], last2[MAX_N];

class Solution {
void do_(int o, int v) {
mn[o] += v;
todo[o] += v;
}

void spread(int o) {
int v = todo[o];
if (v) {
do_(o * 2, v);
do_(o * 2 + 1, v);
todo[o] = 0;
}
}

// 区间 [L,R] 内的数都加上 v o,l,r=1,1,n
void update(int o, int l, int r, int L, int R, int v) {
if (L <= l && r <= R) {
do_(o, v);
return;
}
spread(o);
int m = (l + r) / 2;
if (m >= L) update(o * 2, l, m, L, R, v);
if (m < R) update(o * 2 + 1, m + 1, r, L, R, v);
mn[o] = min(mn[o * 2], mn[o * 2 + 1]);
}

// 查询区间 [L,R] 的最小值 o,l,r=1,1,n
int query(int o, int l, int r, int L, int R) {
if (L <= l && r <= R)
return mn[o];
spread(o);
int m = (l + r) / 2;
if (m >= R) return query(o * 2, l, m, L, R);
if (m < L) return query(o * 2 + 1, m + 1, r, L, R);
return min(query(o * 2, l, m, L, R), query(o * 2 + 1, m + 1, r, L, R));
}

public:
int minCost(vector<int> &nums, int k) {
int n = nums.size(), ans = 0;
memset(mn, 0, sizeof(int) * 4 * n);
memset(todo, 0, sizeof(int) * 4 * n);
memset(last, 0, sizeof(int) * n);
memset(last2, 0, sizeof(int) * n);
for (int i = 1; i <= n; ++i) {
int x = nums[i - 1];
update(1, 1, n, i, i, ans); // 相当于设置 f[i+1] 的值
update(1, 1, n, last[x] + 1, i, -1);
if (last[x]) update(1, 1, n, last2[x] + 1, last[x], 1);
ans = k + query(1, 1, n, 1, i);
last2[x] = last[x];
last[x] = i;
}
return ans + n;
}
};
[sol2-Go]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
// Lazy 线段树模板(区间加,查询区间最小)
type seg []struct{ l, r, min, todo int }

func (t seg) build(o, l, r int) {
t[o].l, t[o].r = l, r
if l == r {
return
}
m := (l + r) >> 1
t.build(o<<1, l, m)
t.build(o<<1|1, m+1, r)
}

func (t seg) do(o, v int) {
t[o].min += v
t[o].todo += v
}

func (t seg) spread(o int) {
if v := t[o].todo; v != 0 {
t.do(o<<1, v)
t.do(o<<1|1, v)
t[o].todo = 0
}
}

// 区间 [l,r] 内的数都加上 v o=1
func (t seg) update(o, l, r, v int) {
if l <= t[o].l && t[o].r <= r {
t.do(o, v)
return
}
t.spread(o)
m := (t[o].l + t[o].r) >> 1
if l <= m {
t.update(o<<1, l, r, v)
}
if m < r {
t.update(o<<1|1, l, r, v)
}
t[o].min = min(t[o<<1].min, t[o<<1|1].min)
}

// 查询区间 [l,r] 的最小值 o=1
func (t seg) query(o, l, r int) int {
if l <= t[o].l && t[o].r <= r {
return t[o].min
}
t.spread(o)
m := (t[o].l + t[o].r) >> 1
if r <= m {
return t.query(o<<1, l, r)
}
if l > m {
return t.query(o<<1|1, l, r)
}
return min(t.query(o<<1, l, r), t.query(o<<1|1, l, r))
}

func minCost(nums []int, k int) (ans int) {
n := len(nums)
last := make([]int, n)
last2 := make([]int, n)
t := make(seg, n*4)
t.build(1, 1, n)
for i, x := range nums {
i++ // 线段树区间从 1 开始
t.update(1, i, i, ans) // 相当于设置 f[i+1] 的值
t.update(1, last[x]+1, i, -1)
if last[x] > 0 {
t.update(1, last2[x]+1, last[x], 1)
}
ans = k + t.query(1, 1, i)
last2[x] = last[x]
last[x] = i
}
return ans + n
}

func min(a, b int) int { if a > b { return b }; return a }

复杂度分析

  • 时间复杂度:O(n\log n),其中 n 为 nums 的长度。
  • 空间复杂度:O(n)。
 Comments