这道题重点要说矩阵快速幂,公式不会推。用了一个Berlekamp-Massey algorithm的模板。最主要想将一般情况下线性递推式怎么用矩阵快速幂优化。
本题目的递推公式是:F(n)=6*F(n-1)-8*F(n-2)-8*F(n-3)+16*F(n-4)。
故构造矩阵递推公式:
\begin{bmatrix}
F(n) \\
F(n-1) \\
F(n-2) \\
F(n-3)
\end{bmatrix} =
\begin{bmatrix}
6 & -8 & -8 & 16 \\
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & 1 & 0
\end{bmatrix} *
\begin{bmatrix}
F(n-1) \\
F(n-2) \\
F(n-3) \\
F(n-4)
\end{bmatrix}
得
\begin{bmatrix}
F(n) \\
F(n-1) \\
F(n-2) \\
F(n-3)
\end{bmatrix} =
\begin{bmatrix}
6 & -8 & -8 & 16 \\
1 & 0 & 0 & 0 \\
0 & 1 & 0 & 0 \\
0 & 0 & 1 & 0
\end{bmatrix}^{n-4} *
\begin{bmatrix}
F(4) \\
F(3) \\
F(2) \\
F(1)
\end{bmatrix}
。
对于一般的线性递推公式F(n)=a_1*F(n-1)+a_2*F(n-2)+…+a_{k-1}*F(n-k+1)+a_k*F(n-k),
可以构造一长宽都为k的矩阵,满足:
\begin{bmatrix}
F(n) \\
F(n-1) \\
… \\
F(n-k+2) \\
F(n-k+1)
\end{bmatrix} =
\begin{bmatrix}
a_1 & a_2 & … & a_{n-k+1} & a_{n-k} \\
1 & 0 & … & 0 & 0 \\
0 & 1 & … & 0 & 0 \\
… & … & … & … & … \\
0 & 0 & … & 1 & 0
\end{bmatrix} *
\begin{bmatrix}
F(n-1) \\
F(n-2) \\
… \\
F(n-k+1) \\
F(n-k)
\end{bmatrix}
因此只需对该矩阵做快速幂,即可以以O(k^3logn)的复杂度推出任意n情况下的数列的值。
本题代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | #include <iostream> using namespace std; const long long MOD = 1e9 + 7; const int maxtrixN=4; struct matrix { long long arr[maxtrixN][maxtrixN]; matrix operator * (matrix m2) { matrix result; long long(&arr)[maxtrixN][maxtrixN] = result.arr; const long long(&arr1)[maxtrixN][maxtrixN] = this->arr; const long long(&arr2)[maxtrixN][maxtrixN] = m2.arr; for (int i = 0; i < maxtrixN; i++) { for (int j = 0; j < maxtrixN; j++) { arr[i][j] = 0; for (int k = 0; k < maxtrixN; k++) { arr[i][j] += arr1[i][k] * arr2[k][j] % MOD; arr[i][j] += MOD; arr[i][j] %= MOD; } } } return result; } }; const matrix c = { 6,-8,-8,16,1,0,0,0,0,1,0,0,0,0,1,0 }; matrix quickpow(long long n) { if (n == 1)return c; matrix half = quickpow(n / 2); matrix result = half * half; if (n & 1)result = result * c; return result; } long long getValue(long long n) { matrix result = quickpow(n); long long(&arr)[4][4] = result.arr; return (arr[0][0] * 1536 % MOD + arr[0][1] * 416 % MOD + arr[0][2] * 96 % MOD + arr[0][3] * 24 % MOD +MOD) % MOD; } int main() { int n; long long a1[] = { 0,2,24,96,416,1536 }; cin >> n; if (n < 6)cout << a1[n]; else cout << getValue(n-5); } |