「题解」UVA1519 Dictionary Size

#0.0 大致题意

给定 \(n(n\le 10^4)\) 个字符串,定义任意一个字符串 \(s\) 为“好的”,当且仅当满足下列任意一个条件:

  1. \(s\) 为这 \(n\) 个字符串其中一个。
  2. \(s\) 可以被划分为一个前缀和一个后缀,其中前缀为给定 \(n\) 个字符串中的一个字符串的前缀,后缀为给定 \(n\) 个字符串中的一个字符串的后缀。

给定 \(n\) 个串,求一共有多少个好的字符串 \(s\) 。数据保证每个字符串长度 \(len\)\(1\le len \le 40\)

注意有多组数据。

#1.0 简单思路

直接计算无重复的字符串有多少个是不好做的,我们来考虑进行容斥,先得到具有重复的字符串数量,再减去重复的贡献。

首先将所有字符串正向插入第一棵 trie 中,那么此时这棵 trie 中的节点数量(不包含根节点)便是所有非空前缀的数量,非空后缀的数量大同小异,不过注意需要将字符串反转,这样从根节点向下到达任意一个节点所经过的路径才能组成一个原串中的后缀。

前后缀数量相乘得到了我们的第一部分答案,但这个答案代表的可重集中是否包含了所有的“好的”字符串呢?并不是,注意到我们拼出来的串的长度 \(\geq2\),如果输入的字典中存在长度为 \(1\) 的字符串,那么他显然是不被包含在其中的,所以我们需要特别记录,但是这样的字符串可以直接去重(每个字符最多单个出现 \(1\) 次)。

最后再来考虑如何去掉重复的贡献。考虑一个前缀为 \(Ax\),一个后缀为 \(xB\)\(x\) 为任意单个字符,\(A,B\) 为两个非空字符串),那么有这样一个结合为 \(AxB\),会以下面两种方式计入当前答案:\(Ax+B\)\(A+xB\),所以说一个形如 \(AxB\) 的字符串便被计算了两遍。

所以对于任何一个以 \(x\) 为结尾的前缀 \(Ax\) 和任何一个以 \(x\) 为开头的后缀 \(xB\),都会有这样的重复计算,所以我们可以统计以 \(x\) 为结尾的长度大于 \(1\) 的前缀数量,和以 \(x\) 为开头的长度大于 \(1\) 的后缀数量,两者相乘便是这个字符做出的重复贡献,应当被减去。

这样的时间复杂度为 \(O(|S|)\),瓶颈在于建立 trie.

更多的细节请见下方代码。

#2.0 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#define mset(l, x) memset(l, x, sizeof(l))

const int N = 2000010;
const int INF = 0x3fffffff;

struct Trie {
int ch[N][30], rt; ll cnt, res[30];

inline void clear() {
cnt = rt = 0;
mset(ch, 0); mset(res, 0);
}

inline void insert(char *s) {
int p = rt, len = strlen(s);
for (int i = 0; i < len; ++ i) {
int k = s[i] - 'a';
if (!ch[p][k]) {
ch[p][k] = ++ cnt;
if (i > 0) ++ res[k];
/*注意统计的前、后缀长度要大于 1*/
}
p = ch[p][k];
}
}
} t[2];

int n, al[30]; ll ans;

inline void reset() { //多组数据,需要清空
ans = 0; mset(al, 0);
t[0].clear(); t[1].clear();
}

int main() {
while (scanf("%d", &n) != EOF) {
reset(); char s[N];
for (int i = 1; i <= n; ++ i) {
scanf("%s", s);
/*分别将正反向加入两棵 trie*/
int len = strlen(s); t[0].insert(s);
reverse(s, s + len); t[1].insert(s);
if (len == 1) al[s[0] - 'a'] = 1;
/*对于单个字符的记录与去重*/
}
ans = t[0].cnt * t[1].cnt;
for (int i = 0; i < 26; ++ i) {
if (al[i]) ++ ans;
ans -= t[0].res[i] * t[1].res[i];
}
printf("%lld\n", ans);
}
return 0;
}