复数
复数的概念
我们把形如 $a+b\mathrm{i}(a, b\in \mathbf{R})$ 的数叫做复数,其中 $\mathrm{i}$ 叫做虚数单位,$\mathrm{i} = \sqrt {-1}$。
全体复数构成的集合 $\mathbf{C}=\{a+b\mathrm{i} \mid a,b\in \mathbf{R}\}$ 叫做复数集。
复数通常用字母 $z$ 表示,即 $z=a+b\mathrm{i}(a, b\in \mathbf{R})$,其中 $a$ 与 $b$ 分别是复数 $z$ 的实部与虚部,分别记作 $\Re(z)$ 和 $\Im(z)$。
可以建立复平面来表示复数,即在直角坐标系 $xOy$ 中,可以用点 $Z(a,b)$ 来表示复数 $z=a+b\mathrm{i}$,其中 $x$ 轴叫做实轴,$y$ 轴叫做虚轴。复数 $z$ 也与向量 $\vec{OZ}$ 一一对应。向量 $\vec{OZ}$ 的模叫做复数 $z$ 的模或绝对值,记作 $|z|$ 或 $|a+b\mathrm{i}|$,即 $|z|=|a+b\mathrm{i}|=\sqrt{a^2+b^2}$。
一般为了方便起见,通常把复数 $z=a+b\mathrm{i}$ 说成点 $Z$ 或说成向量 $\vec{OZ}$。
一般地,当两个复数实部相等,虚部互为相反数时,这两个复数叫做互为共轭复数。复数 $z$ 的共轭复数用 $\overline{z}$ 表示,即如果 $z=a+b\mathrm{i}$,那么 $\overline{z}=a-b\mathrm{i}$。
复数的四则运算
复数的加法
设 $z_1=a+b\mathrm{i}$,$z_2=c+d\mathrm{i}$,则 $z_1+z_2=(a+c)+(b+d)\mathrm{i}$。
满足交换律和结合律。
复数的减法
设 $z_1=a+b\mathrm{i}$,$z_2=c+d\mathrm{i}$,则 $z_1-z_2=(a-c)+(b-d)\mathrm{i}$。
复数的乘法
设 $z_1=a+b\mathrm{i}$,$z_2=c+d\mathrm{i}$,则 $z_1z_2=(ac-bd)+(ad+bc)\mathrm{i}$。
满足交换律,结合律和分配律。
复数的除法
设 $z_1=a+b\mathrm{i}$,$z_2=c+d\mathrm{i}$,则 $z_1\div z_2=\frac{ac+bd}{c^2+d^2}+\frac{bc-ad}{c^2+d^2}\mathrm{i}$。将 $\frac{z_1}{z_2}$ 上下同时乘 $\overline{z_2}$ 即可。
复数的三角表示
一般地,任何一个复数 $z=a+b\mathrm{i}$ 都可以表示为 $r(\cos \theta + \mathrm{i}\sin\theta)$,叫做复数 $z$ 的三角表示式,其中 $r$ 是 $z$ 的模,$\theta$ 是以实轴非负半轴为始边,射线 $OZ$ 为终边的角,叫做 $z$ 的辐角。
一个复数对应着无限多个辐角。我们规定在 $0\le\theta<2\pi$(有些地方好像定义为 $(-\pi,\pi]$)的范围内的 $z$ 的辐角 $\theta$ 的值为辐角的主值,通常记作 $\arg z$。
$z_1z_2 = r_1(\cos\theta_1+\mathrm{i}\sin\theta_1)\cdot r_2(\cos\theta_2+\mathrm{i}\sin\theta_2) = r_1r_2[\cos(\theta_1+\theta_2)+\mathrm{i}\sin(\theta_1+\theta_2)]$,也就是两个复数相乘,模相乘,辐角相加。
$\frac{z_1}{z_2}=\frac{r_1(\cos\theta_1+\mathrm{i}\sin\theta_1)}{r_2(\cos\theta_2+\mathrm{i}\sin\theta_2)} = \frac{r_1}{r_2}[\cos(\theta_1-\theta_2)+\mathrm{i}\sin(\theta_1-\theta_2)]$,即两个复数相除,模相除,辐角相减。
欧拉公式
对于任意实数 $x$ 有 $\mathrm{e}^{\mathrm{i}x}=\cos x + \mathrm{i}\sin x$。
复指数函数
对于复数 $z=a+b\mathrm{i}$,定义复指数函数为 $\exp z = \mathrm{e}^z = \mathrm{e}^a(\cos b + \mathrm{i}\sin b)$。
- $|\exp z| = \exp a \gt 0$。
- $\arg \exp z = b$。
- $\exp(z_1+z_2)=\exp(z_1)\exp(z_2)$。
- $\exp z$ 以 $2\pi \mathrm{i}$ 为最小正周期。
复三角函数
对于复数 $z=a+b\mathrm{i}$,定义复三角函数 $\cos z = \frac{\exp(\mathrm{i}z) + \exp(-\mathrm{i}z)}{2}$,$\sin z = \frac{\exp(\mathrm{i}z) + \exp(-\mathrm{i}z)}{2\mathrm{i}}$。
- 若取 $z\in \mathbf{R}$,则由欧拉公式,有,$\cos z = \Re(\mathrm{e}^{\mathrm{i}z})$,$\sin z = \Im(\mathrm{e}^{\mathrm{i}z})$。
- 复三角函数的正余弦函数均以 $2\pi$ 为最小正周期。
复数的表示形式
代数形式:$z=a+b\mathrm{i}$。
三角形式:$z=r(\cos\theta + \mathrm{i}\sin\theta)$。
指数形式:$z=r\exp(i\theta)$。
单位根
若复数 $z$ 满足 $z^n=1$,则称 $z$ 为 $n$ 次单位根。
设 $\omega_n=\exp \frac{2\pi\mathrm{i}}{n}=\cos(\frac{2\pi}{n}) + \mathrm{i}\sin(\frac{2\pi}{n})$(即辐角为 $\frac{2\pi}{n}$ 的单位复数),则所有 $n$ 次单位根 $z$ 的集合可以表示为 $\{\omega_n^{k}\mid k=0,1,\cdots,n-1\}$,即 $z_k = \omega_{n}^{k}$。
每个单位根 $z_k$ 对应复平面单位圆上一个点,其辐角为 $\frac{2k\pi}{n}$。
所有单位根将单位圆分成 $n$ 等份,构成正 $n$ 边形的顶点。
性质:
- $\omega_{n}^{n}=1$。
- $\omega_{n}^{k}=\exp \frac{2k\pi\mathrm{i}}{n} = \cos(\frac{2k\pi}{n})+\mathrm{i}\sin(\frac{2k\pi}{n}) = \cos(\frac{2kx\pi}{nx})+\mathrm{i}\sin(\frac{2kx\pi}{nx}) = \omega_{nx}^{kx}$。
- $\omega_{2n}^{k+n}=-\omega_{2n}^{k}$。
本原单位根
称集合 $\{\omega_{n}^{k}\mid 0\le k\lt n,\gcd(n,k)=1\}$ 中的元素为本原单位根。
对于任何一个本原单位根 $\omega_n$,它的 $k(0\le k\lt n)$ 次幂互不相同。从互质和整除的角度出发容易证明。
快速傅里叶变换 FFT
一种能在 $O(n\log n)$ 的时间内计算两个给定的 $n$ 次多项式的乘法的算法。
离散傅里叶变换 DFT & 离散傅里叶逆变换 IDFT
求出一个 $n$ 次多项式在 $n$ 个 $n$ 次单位根下的点值的过程称为离散傅里叶变换 DFT。
而将这些点值重新插值成系数表示法的过程称为离散傅里叶逆变换 IDFT。
DFT:$F_k = \sum\limits_{i=0}^{n-1}G_i\omega_{n}^{ik}$。
IDFT:$G_k=\frac{1}{n} \sum\limits_{i=0}^{n-1}F_i\omega_{n}^{-ik}$。(下面有证明)
若要用 FFT 计算多项式乘法 $C(x)=A(x)\cdot B(x)$。其基本思想就是,在 $O(n\log n)$ 时间内进行两遍 DFT 分别求出 $A(x)$ 和 $B(x)$ 的点值,再 $O(n)$ 将点值相乘合并为 $C(x)$ 的点值,最后在 $O(n\log n)$ 时间内进行 IDFT 插值成 $C(x)$ 的系数表示法。
快速傅里叶变换 FFT
FFT 算法可以 $O(n\log n)$ 实现 DFT。
设 $n-1$ 次多项式 $A(x)=a_0 + a_1x + a_2x^2 + \cdots + a_{n-1}x^{n-1}$,
为了方便处理,以下我们将 $A(x)$ 项数保证为 $2$ 的整数次幂,若次数不够,则补系数 $0$ 即可。即 $A(x)$ 是不超过 $n-1$ 次多项式,且 $n$ 是 $2$ 的整数次幂。
考虑如何快速计算其在 $n$ 个 $n$ 次单位根下的点值,即 $A(\omega_{n}^0),A(\omega_{n}^1),\dots,A(\omega_{n}^{n-1})$,
设, $$ A_1(x) = a_0 + a_2x + a_4x^2 + \cdots + a_{n-2}x^{\frac{n}{2}-1}, \\ A_2(x) = a_1 + a_3x + a_5x^2 + \cdots + a_{n-1}x^{\frac{n}{2}-1}, $$ 则 $A(x) = A_1(x^2) + xA_2(x^2)$,
设 $0\le k\le \frac{n}{2}-1$,则, $$ \begin{aligned} A(\omega_{n}^{k}) &= A_1(\omega_{n}^{2k}) + \omega_{n}^{k}A_2(\omega_{n}^{2k}) \\ &= A_1(\omega_{\frac{n}{2}}^{k}) + \omega_{n}^{k}A_2(\omega_{\frac{n}{2}}^{k}) \end{aligned} $$ 又有, $$ \begin{aligned} A(\omega_{n}^{k+\frac{n}{2}}) &= A_1(\omega_{n}^{2k+n}) + \omega_{n}^{k+\frac{n}{2}}A_2(\omega_{n}^{2k+n}) \\ &= A_1(\omega_{n}^{2k}) - \omega_{n}^{k}A_2(\omega_{n}^{2k}) \\ &= A_1(\omega_{\frac{n}{2}}^{k}) - \omega_{n}^{k}A_2(\omega_{\frac{n}{2}}^{k}) \end{aligned} $$ 那么我们只需要求出 $A_1(\omega_{\frac{n}{2}}^{k})$ 和 $A_2(\omega_{\frac{n}{2}}^{k})$,就可以知道 $A(\omega_{n}^{k})$ 和 $A(\omega_{n}^{k+\frac{n}{2}})$,
由于 $k$ 和 $k+\frac{n}{2}$ 取遍了 $[0,n-1]$,于是我们可以利用 $A_1(x)$ 和 $A_2(x)$ 求出 $A(x)$ 在 $n$ 个 $n$ 次单位根下的点值,
由于 $A_1(x)$ 和 $A_2(x)$ 的项数都是 $\frac{n}{2}$,即为 $A(x)$ 问题规模的一半的子问题,于是可以在 $T(n)=2T(\frac{n}{2})+O(n)=O(n\log n)$ 的复杂度内求出 $A(x)$ 的 $n$ 个点值。
快速傅里叶逆变换 IFFT
IFFT 可以用 FFT 的方法实现,可以 $O(n\log n)$ 实现 IDFT。
我们已知 $y_i=A(\omega_{n}^{i})$,现在要快速求出 $A(x)$ 的系数 $a_0,a_1,\dots,a_{n-1}$,
考虑构造如下多项式, $$ F(x) = \sum\limits_{i=0}^{n-1} y_ix^i $$ 设 $b_i=\omega_{n}^{-i}$,则, $$ \begin{aligned} F(b_k) &= \sum\limits_{i=0}^{n-1}A(\omega_{n}^{i})\omega_{n}^{-ki} \\ &= \sum\limits_{i=0}^{n-1}\omega_{n}^{-ki}\sum\limits_{j=0}^{n-1}a_{j}\omega_{n}^{ij} \\ &= \sum\limits_{j=0}^{n-1}a_j\sum\limits_{i=0}^{n-1}\omega_{n}^{(j-k)i} \end{aligned} $$ 记 $S(\omega_{n}^{x})=\sum\limits_{i=0}^{n-1}\omega_{n}^{xi}$,
当 $x\bmod n = 0$ 时,$S(\omega_{n}^{x})=n$。
当 $x\bmod n \neq 0$ 时, $$ \begin{aligned} S(\omega_{n}^{x}) &= \sum\limits_{i=0}^{n-1}\omega_{n}^{xi} & (1) \\ \omega_{n}^{x}S(\omega_{n}^{x}) &= \sum\limits_{i=1}^{n}\omega_{n}^{xi} & (2) \\ \end{aligned} $$ $(2)-(1)$ 得, $$ S(\omega_{n}^{x}) = \frac{\omega_{n}^{xn}-\omega_{n}^{0}}{\omega_{n}^{x}-1} = 0 $$ 因此, $$ \begin{aligned} F(b_k)&=a_k\cdot n \\ a_k &= \frac{F(\omega_{n}^{-k})}{n}=\frac{F(\omega_{n}^{n-k})}{n} \end{aligned} $$ 同时也证明了 IDFT。
于是我们相当于要对构造多项式 $F(x)$ 求出 $n$ 次单位根下的点值 $F(\omega_{n}^{k})$,然后除以 $n$ 得到 $A(x)$ 的系数 $a_{n-k}$。
或者可以 FFT 直接求出 $F(\omega_{n}^{-k})$ 的点值,再除以 $n$。由于 $\omega_{n}^{-k}=(\omega_{n}^{-1})^k$,相当于 FFT 证明过程中多带了一个 $-1$ 次幂,不影响 FFT 的推导。
这里可以用数值分析的方法证明除以 $n$ 四舍五入的答案不会受精度影响。
解释
FFT 为什么要选择单位根求值?
首先,初始的单位根必须选择本原单位根,保证了其 $0$ 到 $n-1$ 次幂一定互不相同,因此就能确定 $n-1$ 次多项式。这也说明了 IFFT 时选择 $\omega_n^{-1}=\omega_n^{n-1}$ 的正确性。
其次,若我们随便选择一个整数 $v$,将 $v^0,v^1,v^2,\cdots,v^{n-1}$ 代入求值,那么一样也可以推出 $A(v^k)=A_1(v^{2k})+v^kA_2(v^{2k})$,但是只有在 $k\in[0,\frac{n}{2}-1]$ 时成立。
单位根的好处就是其利用了以下几个性质:
- $\omega_{n}^{k+\frac{n}{2}}=-\omega_{n}^{k}$。
- $\omega_{n}^{k+n}=\omega_{n}^{k}$。
- $\omega_{mn}^{mk}=\omega_{n}^{k}$。
NTT 即选择了与单位根有类似性质的原根来代入求值,相对于 FFT 的优点就是常数较小,且没有精度误差。但是 NTT 只在对一些特殊质数取模才有效。
代码实现
分治法
用一个函数可以同时实现 FFT 和 IFFT。
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
const int N = 1e6 + 5;
const double PI = acos(-1.0);
int n, m;
void FFT(vector<complex<double>>& a, int n, int o) {
if (n == 1) return;
vector<complex<double>> a1(n / 2), a2(n / 2);
for (int i = 0; i < n; i += 2) a1[i / 2] = a[i], a2[i / 2] = a[i + 1];
FFT(a1, n / 2, o), FFT(a2, n / 2, o);
complex<double> p(1, 0), w(cos(2 * PI / n), o * sin(2 * PI / n));
for (int i = 0; i < n / 2; i++) {
a[i] = a1[i] + p * a2[i];
a[i + n / 2] = a1[i] - p * a2[i];
p *= w;
}
}
int main() {
cin >> n >> m;
int tot = 1;
while (tot < n + m + 1) tot <<= 1;
vector<complex<double>> a(tot), b(tot);
for (int i = 0; i <= n; i++) {
double x;
cin >> x;
a[i].real(x);
}
for (int i = 0; i <= m; i++) {
double x;
cin >> x;
b[i].real(x);
}
FFT(a, tot, 1);
FFT(b, tot, 1);
vector<complex<double>> c(tot);
for (int i = 0; i < tot; i++) c[i] = a[i] * b[i];
FFT(c, tot, -1);
for (int i = 0; i <= n + m; i++) cout << (int)round(c[i].real() / tot) << ' ';
cout << '\n';
return 0;
}
蝶形变换法
分治法实现的常数非常大,且空间复杂度是 $O(n\log n)$ 的,于是有了更高效的蝶形变换法。
考虑 FFT 的时候,每次分治都要把奇偶位的系数分开,那么会不会存在什么规律?
假设 $n-1$ 次 $n$ 项多项式 $A(x)$ 的系数有 $a_0,a_1,a_2,a_3,a_4,a_5,a_6,a_7$,$n=8$,在 FFT 分治到最后的时候,系数位置会变成 $a_0,a_4,a_2,a_6,a_1,a_5,a_3,a_7$,
我们发现 $a_i$ 位置上的系数,最后变成 $a_j$,满足 $j$ 是 $i$ 在补全 $m=\log_2n$ 位二进制后的二进制翻转。例如 $(000)_2\to(000)_2$,$(001)_2\to(100)_2$,$(110)_2\to (011)_2$。
证明:
考虑系数 $a_x$。下证分治过后,$a_x$ 的位置会变成 $y$,$y$ 是 $x$ 在补全 $m$ 位二进制后的二进制翻转。
设 $x$ 的二进制第 $i$ 位是 $x_i$,$i$ 从 $0$ 开始。
分治一共会进行 $m$ 层,第 $m$ 层只剩下一个系数,也就是区间长度为 $1$。
从 $1$ 到 $m-1$ 考虑每一层,我们模拟一遍递归的过程。
若 $x$ 最低位是 $1$,则下一层 $x\to \frac{n}{2}+\frac{x-1}{2}$,但是下一层之后我们为了方便计算 $x$ 的目标位置,因此不能把 $\frac{n}{2}$ 带下去,于是记一个 $k$(第一层之前 $k=0$),然后 $k\to k+\frac{n}{2}$,$x\to \frac{x-1}{2}$,$n\to \frac{n}{2}$。
若 $x$ 最低位是 $0$,则下一层 $k\to k+0$,$x\to \frac{x}{2}$, $n\to \frac{n}{2}$。
那么每一层相当于把 $x$ 当前最低位,以 $x$ 二进制位中线为对称轴翻转,计入变量 $k$,然后把 $x$ 最低位去掉(除以 $2$)。
那么这样递归过后,最终 $x$ 的值变为 $0$,然后要加上 $k$ 才是 $x$ 在原序列中的位置,容易看出 $k$ 就是一开始的 $x$ 的二进制翻转。
证毕。
于是我们可以提前交换好系数 $a_i$ 的位置,然后自底向上合并,合并的时候可以在原数组上进行覆写,那么空间复杂度就降低到了 $O(n)$,且常数比分治法的递归小了很多。
如何求一个数的二进制翻转呢?这里可以 $O(n)$ 递推实现。
设 $\mathrm{rev}(x)$ 表示 $x$ 的二进制翻转,显然 $\mathrm{rev}(0)=0$,
假设我们要求 $\mathrm{rev}(x)$,且 $\mathrm{rev}(\lfloor\frac{x}{2}\rfloor)$ 已经求得,于是有, $$ \mathrm{rev}(x)=\lfloor\frac{\mathrm{rev}(\lfloor\frac{x}{2}\rfloor)}{2}\rfloor + [x\bmod 2=1]\cdot \frac{n}{2} $$ 例如 $n=32$,$(01101)_2\to (00110)_2 \to (01100)_2\to (00110)_2 \to(10110)$。
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
const int N = 4e6 + 5;
const double PI = acos(-1.0);
int n, m;
complex<double> a[N], b[N], c[N];
int rev[N];
void FFT(complex<double> *a, int n, int o) {
rev[0] = 0;
for (int i = 1; i < n; i++) rev[i] = (rev[i >> 1] >> 1) + (i & 1) * n / 2;
for (int i = 0; i < n; i++) if (rev[i] < i) swap(a[i], a[rev[i]]);
for (int len = 2; len <= n; len <<= 1) {
for (int l = 0; l + len - 1 < n; l += len) {
int mid = len >> 1;
complex<double> w(cos(2 * PI / len), o * sin(2 * PI / len)), p(1, 0);
for (int i = 0; i < mid; i++) {
auto x = a[l + i], y = a[l + i + mid];
a[l + i] = x + p * y;
a[l + i + mid] = x - p * y;
p *= w;
}
}
}
if (o == -1) for (int i = 0; i < n; i++) a[i] /= n;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
cin >> n >> m;
double x;
for (int i = 0; i <= n; i++) {
cin >> x;
a[i].real(x);
}
for (int i = 0; i <= m; i++) {
cin >> x;
b[i].real(x);
}
int tot = 1;
while (tot < n + m + 1) tot <<= 1;
FFT(a, tot, 1);
FFT(b, tot, 1);
for (int i = 0; i < tot; i++) c[i] = a[i] * b[i];
FFT(c, tot, -1);
for (int i = 0; i <= n + m; i++) cout << (int)round(c[i].real()) << ' ';
cout << '\n';
return 0;
}
三步变两步
设多项式 $F(x) = A(x) + B(x)\mathrm{i}$,则 $F(x)^2 = A(x)^2 - B(x)^2 + 2\cdot A(x)\cdot B(x) \cdot \mathrm{i}$,于是 $A(x)\cdot B(x) = \frac{\Im(F(x))}{2}$。
于是若要求 $A(x)\cdot B(x)$ 的系数,可以直接求出 $F(x)^2$,再取 $F(x)^2$ 每一项系数的虚部除以 $2$ 即可,只需两步 FFT。
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
const int N = 4e6 + 5;
const double PI = acos(-1.0);
int n, m;
complex<double> a[N];
int rev[N];
void FFT(complex<double> *a, int n, int o) {
rev[0] = 0;
for (int i = 1; i < n; i++) rev[i] = (rev[i >> 1] >> 1) + (i & 1) * n / 2;
for (int i = 0; i < n; i++) if (rev[i] < i) swap(a[i], a[rev[i]]);
for (int len = 2; len <= n; len <<= 1) {
for (int l = 0; l + len - 1 < n; l += len) {
int mid = len >> 1;
complex<double> w(cos(2 * PI / len), o * sin(2 * PI / len)), p(1, 0);
for (int i = 0; i < mid; i++) {
auto x = a[l + i], y = a[l + i + mid];
a[l + i] = x + p * y;
a[l + i + mid] = x - p * y;
p *= w;
}
}
}
if (o == -1) for (int i = 0; i < n; i++) a[i] /= n;
}
int main() {
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
cin >> n >> m;
double x;
for (int i = 0; i <= n; i++) {
cin >> x;
a[i].real(x);
}
for (int i = 0; i <= m; i++) {
cin >> x;
a[i].imag(x);
}
int tot = 1;
while (tot < n + m + 1) tot <<= 1;
FFT(a, tot, 1);
for (int i = 0; i < tot; i++) a[i] *= a[i];
FFT(a, tot, -1);
for (int i = 0; i <= n + m; i++) cout << (int)round(a[i].imag() / 2) << ' ';
cout << '\n';
return 0;
}
各优化时间对比
题目
P1919 【模板】高精度乘法 | A*B Problem 升级版
相当于给了两个多项式在 $x=10$ 位置的点值,求多项式相乘后在 $x=10$ 位置的点值。
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
const int N = 4e6 + 5;
const double PI = acos(-1.0);
string s, t;
complex<double> a[N];
int rev[N], ans[N];
void FFT(complex<double> *a, int n, int o) {
rev[0] = 0;
for (int i = 1; i < n; i++) rev[i] = (rev[i >> 1] >> 1) + (i & 1) * n / 2;
for (int i = 0; i < n; i++) if (rev[i] < i) swap(a[i], a[rev[i]]);
for (int len = 2; len <= n; len <<= 1) {
for (int l = 0; l + len - 1 < n; l += len) {
int mid = len >> 1;
complex<double> w(cos(2 * PI / len), o * sin(2 * PI / len)), p(1, 0);
for (int i = 0; i < mid; i++) {
auto x = a[l + i], y = a[l + i + mid];
a[l + i] = x + p * y;
a[l + i + mid] = x - p * y;
p *= w;
}
}
}
}
int main() {
cin >> s >> t;
int n = s.size(), m = t.size();
for (int i = n - 1; i >= 0; i--) a[i].real(s[n - i - 1] - '0');
for (int i = m - 1; i >= 0; i--) a[i].imag(t[m - i - 1] - '0');
int tot = 1;
while (tot < n + m - 1) tot <<= 1;
FFT(a, tot, 1);
for (int i = 0; i < tot; i++) a[i] *= a[i];
FFT(a, tot, -1);
int len = n + m - 1;
for (int i = 0; i < len; i++) {
ans[i] += round(a[i].imag() / tot / 2);
if (ans[i] >= 10) {
ans[i + 1] += ans[i] / 10;
ans[i] %= 10;
len += (i == len - 1);
}
}
while (len > 1 && ans[len - 1] == 0) len--;
for (int i = len - 1; i >= 0; i--) cout << ans[i];
cout << '\n';
return 0;
}
HDU4609 3-idiots
设 $f(i)$ 表示相加起来 $=i$ 的两根木棍的对数。
设 $g(i)$ 表示长度 $=i$ 的木棍的个数,即对 $a_i$ 开一个桶。
把 $g$ 看成多项式,$g(i)$ 是 $i$ 次项的系数,设 $h=g^2$,则, $$ h(i) = \sum\limits_{j=1}^{i} g(j)\cdot g(i-j) $$ $h$ 用 FFT 计算即可。
那么 $f(i) = h(i) - \sum\limits_{j=1}^{n} [a_j+a_j=i]$,即减去两个重复选同一根的方案。
于是我们对 $a_i$ 升序排序,枚举 $i\in [1,n]$,计算以 $a_i$ 为三根木棍中的最长木棍,且三根木棍能构成三角形的方案数。
- 首先答案加上 $\sum\limits_{j=a_i+1} f(j)$,表示选择的另外两根木棍长度和 $>a_i$。
- 由于我们要保证两根最小的长度加起来大于第三根,所以剩下两根必须选择 $\le a_i$ 的,并且不能重复选 $a_i$,因此要减去一些不合法的方案。
- 减去 $(n-i)\cdot (i-1)$,表示一根选择了 $\ge a_i$ 的,一根选择 $\le a_i$ 的,但都没有重复选 $a_i$ 的方案数。
- 减去 $n-1$,表示一根选择了 $a_i$ 这根(注意是 $a_i$ 这根,而不是长度 $=a_i$,也就是选重复了),另一根选了其他的方案数。
- 减去 $(n-i)\cdot (n-i-1) \cdot\frac{1}{2}$,表示两个都选择了 $\ge a_i$,都没有重复选 $a_i$ 的方案数。
- 两个都重复选 $a_i$ 的方案数已经在前面计算 $f$ 的时候减过了。
最后合法方案除以总方案就是概率了。
#include <bits/stdc++.h>
using namespace std;
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
const int N = 4e5 + 5, V = 2e5;
const double PI = acos(-1.0);
int n, a[N], tot;
LL b[N], f[N], sum[N];
complex<double> B[N];
namespace Poly {
int rev[N];
void FFT(complex<double> *a, int n, int o) {
rev[0] = 0;
for (int i = 1; i < n; i++) rev[i] = (rev[i >> 1] >> 1) + (i & 1) * n / 2;
for (int i = 0; i < n; i++) if (rev[i] < i) swap(a[i], a[rev[i]]);
for (int len = 2; len <= n; len <<= 1) {
for (int l = 0; l + len - 1 < n; l += len) {
int mid = len >> 1;
complex<double> p(1, 0), w(cos(2 * PI / len), o * sin(2 * PI / len));
for (int i = 0; i < mid; i++) {
auto x = a[l + i], y = a[l + i + mid];
a[l + i] = x + p * y;
a[l + i + mid] = x - p * y;
p *= w;
}
}
}
}
} using Poly::FFT;
void Solve() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
b[a[i]]++;
}
sort(a + 1, a + 1 + n);
int m = a[n], tot = 1;
while (tot < m + m + 1) tot <<= 1;
// 计算 f = b * b
for (int i = 0; i <= m; i++) B[i].real(b[i]);
FFT(B, tot, 1);
for (int i = 0; i < tot; i++) B[i] *= B[i];
FFT(B, tot, -1);
for (int i = 0; i < tot; i++) B[i] /= tot;
for (int i = 0; i < tot; i++) f[i] = round(B[i].real());
for (int i = 1; i <= n; i++) f[a[i] + a[i]]--; // 减去重复选的 a[i]+a[i]
for (int i = 1; i < tot; i++) sum[i] = sum[i - 1] + f[i] / 2; // 除以2 是因为一对 a[i]+a[j] 会算两次
LL ans = 0;
// 枚举选取的最大长度是 a[i]
for (int i = 1; i <= n; i++) {
ans += sum[tot - 1] - sum[a[i]]; // 先计算上 加起来 >a[i] 的方案数
// 因为 a[i] 的最大长度,因此我们要求剩余选的两个必须都 <=a[i],且加起来 >a[i]
ans -= (LL)(n - i) * (i - 1); // 一个选 >=a[i] 的,另一个选 <=a[i] 的,不合法
ans -= n - 1; // 一个选 a[i],另一个选不是 a[i] 的,不合法
ans -= (LL)(n - i) * (n - i - 1) / 2; // 两个都选 >=a[i] 的,不合法
// 两个都选 a[i] 已经在前面减过了
// 剩下的就是两个都 <=a[i],且加起来 >a[i] 的方案数
}
printf("%.7lf\n", (double)ans / ((LL)n * (n - 1) * (n - 2) / 6));
// clear
for (int i = 0; i < tot; i++) {
b[i] = f[i] = sum[i] = 0;
B[i] = {0, 0};
}
}
int main() {
int T;
cin >> T;
while (T--) Solve();
return 0;
}