矩阵算法方法(算法之2矩阵乘法的Strassen算法)
摘要:Strassen算法证明了存在时间复杂度低于 $\Theta(n^3)$ 的矩阵乘法算法。假设矩阵 A 和矩阵 B 都是 $n \times n$ 的方矩阵,本文讨论如何求 C=AB 。
一般的矩阵乘法算法时间复杂度为 $\Theta(n^3)$。
1969年,Volker Strassen第一个提出了复杂度低于 $\Theta(n^3)$ 的矩阵乘法算法,算法时间复杂度为 $\Theta(n^{\log_2 7}) \approx \Theta(n^{2.81})$。Strassen算法证明了存在时间复杂度低于 $\Theta(n^3)$ 的算法。
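假设矩阵 A 和矩阵 B 都是 $n \times n$($n = 2^k$)的方矩阵,求 C=AB ,如下所示:
$$A = \begin{bmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{bmatrix},\quad B = \begin{bmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{bmatrix},\quad C = \begin{bmatrix} C_{11} & C_{12} \\ C_{21} & C_{22} \end{bmatrix}$$
其中 $A_{ij}$、$B_{ij}$、$C_{ij}$ 都是 $\frac{n}{2} \times \frac{n}{2}$ 的子矩阵。
矩阵 C 可以通过下列公式求出:
$$C_{11} = A_{11}B_{11} + A_{12}B_{21}$$
$$C_{12} = A_{11}B_{12} + A_{12}B_{22}$$
$$C_{21} = A_{21}B_{11} + A_{22}B_{21}$$
$$C_{22} = A_{21}B_{12} + A_{22}B_{22}$$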
从上述公式我们可以得出,计算2个 n * n 的矩阵相乘需要对 $\frac{n}{2} \times \frac{n}{2}$ 的矩阵做8次乘法和4次加法。我们使用 T (n) 表示 n * n 矩阵乘法的时间复杂度,那么我们可以根据上面的分解得到下面的递推公式:
$$T(n) = 8T\left(\frac{n}{2}\right) + \Theta(n^2)$$
其中,
- 1. $8T\left(\frac{n}{2}\right)$ 表示8次矩阵乘法,而且相乘的矩阵规模降到了 $\frac{n}{2}$。
- 2. $\Theta(n^2)$ 表示4次矩阵加法的时间复杂度以及合并矩阵 C 的时间复杂度。
最终可计算得到 $T(n) = \Theta(n^3)$。
现在,我们来看一下Strassen算法的原理。
仍然把每个矩阵分割为4份,然后创建如下10个中间矩阵:
S1 = B12 - B22
S2 = A11 + A12
S3 = A21 + A22
S4 = B21 - B11
S5 = A11 + A22
S6 = B11 + B22
S7 = A12 - A22
S8 = B21 + B22
S9 = A11 - A21
S10 = B11 + B12
接着,计算7次矩阵乘法:
P1 = A11 • S1
P2 = S2 • B22
P3 = S3 • B11
P4 = A22 • S4
P5 = S5 • S6
P6 = S7 • S8
P7 = S9 • S10
最后,根据这7个结果就可以计算出C矩阵:
C11 = P5 + P4 - P2 + P6
C12 = P1 + P2
C21 = P3 + P4
C22 = P5 + P1 - P3 - P7
由于只需要7次规模为 n/2 的矩阵乘法,再加上常数次矩阵加减法,Strassen算法的递推公式为:
$T(n) = 7T(n/2) + \Theta(n^2)$
使用递归树或主方法可以计算出结果:
$T(n) = \Theta(n^{\lg 7}) \approx \Theta(n^{2.81})$
下图展示了平凡算法和Strassen算法的性能差异,n越大,Strassen算法节约的时间越多。
代码如下:
import java.util.Arrays;
/**
 * Three ways of multiplying two square {@code int} matrices of equal size:
 * the schoolbook triple loop, a plain divide-and-conquer recursion that performs
 * eight half-size multiplications, and Strassen's recursion that performs seven.
 *
 * <p>All arithmetic is done in {@code int}; products and sums that exceed the
 * {@code int} range wrap around silently.
 */
public class MatrixMultiply {

    /**
     * Multiplies {@code A} by {@code B} with the schoolbook triple loop and prints
     * the product via {@link #displaySquare(int[][])}.
     *
     * @param A left operand; read as an {@code A.length x A.length} matrix
     * @param B right operand; must have the same number of rows as {@code A}
     * @throws IllegalArgumentException if the two matrices have different row counts
     */
    public static void SquareMatrixMultiply(int A[][], int B[][]) {
        requireSameSize(A, B);
        int rows = A.length;
        int C[][] = new int[rows][rows];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < rows; j++) {
                // Accumulate the dot product of row i of A and column j of B.
                int sum = 0;
                for (int k = 0; k < rows; k++) {
                    sum += A[i][k] * B[k][j];
                }
                C[i][j] = sum;
            }
        }
        displaySquare(C);
    }

    /**
     * Prints the matrix to standard output, one row per line, every element
     * followed by a single space.
     *
     * @param matrix the matrix to print
     */
    public static void displaySquare(int matrix[][]) {
        for (int[] row : matrix) {
            for (int value : row) {
                System.out.print(value + " ");
            }
            System.out.println();
        }
    }

    /**
     * Copies the {@code iLen x jLen} window of {@code srcMatrix} whose top-left
     * corner is {@code (startI, startJ)} into the top-left corner of {@code destMatrix}.
     *
     * @param srcMatrix  matrix to read from
     * @param startI     first source row of the window
     * @param startJ     first source column of the window
     * @param iLen       number of rows to copy
     * @param jLen       number of columns to copy
     * @param destMatrix matrix to write to, starting at {@code [0][0]}
     */
    public static void copyToMatrixArray(int srcMatrix[][], int startI, int startJ, int iLen, int jLen,
            int destMatrix[][]) {
        for (int i = startI; i < startI + iLen; i++) {
            for (int j = startJ; j < startJ + jLen; j++) {
                destMatrix[i - startI][j - startJ] = srcMatrix[i][j];
            }
        }
    }

    /**
     * Copies the top-left {@code iLen x jLen} window of {@code srcMatrix} into
     * {@code destMatrix} so that it lands with its top-left corner at
     * {@code (startI, startJ)}.
     *
     * @param destMatrix matrix to write to
     * @param startI     first destination row
     * @param startJ     first destination column
     * @param iLen       number of rows to copy
     * @param jLen       number of columns to copy
     * @param srcMatrix  matrix to read from, starting at {@code [0][0]}
     */
    public static void copyFromMatrixArray(int destMatrix[][], int startI, int startJ, int iLen, int jLen,
            int srcMatrix[][]) {
        for (int i = 0; i < iLen; i++) {
            for (int j = 0; j < jLen; j++) {
                destMatrix[startI + i][startJ + j] = srcMatrix[i][j];
            }
        }
    }

    /**
     * Stores the element-wise sum {@code A + B} into {@code C}. Each cell is read
     * before the same cell is written, so {@code C} may be the same array as
     * {@code A} or {@code B}.
     *
     * @param A first addend; its dimensions drive the loops
     * @param B second addend
     * @param C destination
     */
    public static void squareMatrixAdd(int A[][], int B[][], int C[][]) {
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A[i].length; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
    }

    /**
     * Stores the element-wise difference {@code A - B} into {@code C}. Each cell is
     * read before the same cell is written, so {@code C} may be the same array as
     * {@code A} or {@code B}.
     *
     * @param A minuend; its dimensions drive the loops
     * @param B subtrahend
     * @param C destination
     */
    public static void squareMatrixSub(int A[][], int B[][], int C[][]) {
        for (int i = 0; i < A.length; i++) {
            for (int j = 0; j < A[i].length; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
    }

    /**
     * Multiplies two square matrices by splitting each into four quadrants and
     * recursing on the eight quadrant products.
     *
     * <p>Sizes that are not a power of two are zero-padded up to the next power of
     * two before recursing and the product is cropped back afterwards, so the halving
     * never drops a row or column. An empty input yields an empty matrix.
     *
     * @param A left operand; read as an {@code A.length x A.length} matrix
     * @param B right operand; must have the same number of rows as {@code A}
     * @return a newly allocated {@code A.length x A.length} product matrix
     * @throws IllegalArgumentException if the two matrices have different row counts
     */
    public static int[][] squareMatrixMultiplyRecursive(int A[][], int B[][]) {
        requireSameSize(A, B);
        int n = A.length;
        if (n == 0) {
            return new int[0][0];
        }
        int padded = nextPowerOfTwo(n);
        if (padded == n) {
            return multiplyByQuadrants(A, B);
        }
        return resize(multiplyByQuadrants(resize(A, padded), resize(B, padded)), n);
    }

    /**
     * Multiplies two square matrices with Strassen's algorithm: ten quadrant
     * sums/differences, seven recursive half-size products, and a final
     * recombination.
     *
     * <p>Sizes that are not a power of two are zero-padded up to the next power of
     * two before recursing and the product is cropped back afterwards. An empty
     * input yields an empty matrix.
     *
     * @param A left operand; read as an {@code A.length x A.length} matrix
     * @param B right operand; must have the same number of rows as {@code A}
     * @return a newly allocated {@code A.length x A.length} product matrix
     * @throws IllegalArgumentException if the two matrices have different row counts
     */
    public static int[][] strassenMatrixMultiplyRecursive(int A[][], int B[][]) {
        requireSameSize(A, B);
        int n = A.length;
        if (n == 0) {
            return new int[0][0];
        }
        int padded = nextPowerOfTwo(n);
        if (padded == n) {
            return multiplyStrassen(A, B);
        }
        return resize(multiplyStrassen(resize(A, padded), resize(B, padded)), n);
    }

    /** Eight-product divide and conquer; the size must be a power of two (>= 1). */
    private static int[][] multiplyByQuadrants(int[][] a, int[][] b) {
        int n = a.length;
        if (n == 1) {
            return new int[][] {{a[0][0] * b[0][0]}};
        }
        int half = n / 2;
        int[][] a11 = quadrant(a, 0, 0, half);
        int[][] a12 = quadrant(a, 0, half, half);
        int[][] a21 = quadrant(a, half, 0, half);
        int[][] a22 = quadrant(a, half, half, half);
        int[][] b11 = quadrant(b, 0, 0, half);
        int[][] b12 = quadrant(b, 0, half, half);
        int[][] b21 = quadrant(b, half, 0, half);
        int[][] b22 = quadrant(b, half, half, half);

        int[][] c11 = sum(multiplyByQuadrants(a11, b11), multiplyByQuadrants(a12, b21));
        int[][] c12 = sum(multiplyByQuadrants(a11, b12), multiplyByQuadrants(a12, b22));
        int[][] c21 = sum(multiplyByQuadrants(a21, b11), multiplyByQuadrants(a22, b21));
        int[][] c22 = sum(multiplyByQuadrants(a21, b12), multiplyByQuadrants(a22, b22));
        return assemble(c11, c12, c21, c22);
    }

    /** Seven-product Strassen recursion; the size must be a power of two (>= 1). */
    private static int[][] multiplyStrassen(int[][] a, int[][] b) {
        int n = a.length;
        if (n == 1) {
            return new int[][] {{a[0][0] * b[0][0]}};
        }
        int half = n / 2;
        int[][] a11 = quadrant(a, 0, 0, half);
        int[][] a12 = quadrant(a, 0, half, half);
        int[][] a21 = quadrant(a, half, 0, half);
        int[][] a22 = quadrant(a, half, half, half);
        int[][] b11 = quadrant(b, 0, 0, half);
        int[][] b12 = quadrant(b, 0, half, half);
        int[][] b21 = quadrant(b, half, 0, half);
        int[][] b22 = quadrant(b, half, half, half);

        // The ten auxiliary sums/differences S1..S10.
        int[][] s1 = difference(b12, b22);
        int[][] s2 = sum(a11, a12);
        int[][] s3 = sum(a21, a22);
        int[][] s4 = difference(b21, b11);
        int[][] s5 = sum(a11, a22);
        int[][] s6 = sum(b11, b22);
        int[][] s7 = difference(a12, a22);
        int[][] s8 = sum(b21, b22);
        int[][] s9 = difference(a11, a21);
        int[][] s10 = sum(b11, b12);

        // The seven half-size products P1..P7.
        int[][] p1 = multiplyStrassen(a11, s1);
        int[][] p2 = multiplyStrassen(s2, b22);
        int[][] p3 = multiplyStrassen(s3, b11);
        int[][] p4 = multiplyStrassen(a22, s4);
        int[][] p5 = multiplyStrassen(s5, s6);
        int[][] p6 = multiplyStrassen(s7, s8);
        int[][] p7 = multiplyStrassen(s9, s10);

        // C11 = P5 + P4 - P2 + P6, C12 = P1 + P2, C21 = P3 + P4, C22 = P5 + P1 - P3 - P7.
        int[][] c11 = sum(difference(sum(p5, p4), p2), p6);
        int[][] c12 = sum(p1, p2);
        int[][] c21 = sum(p3, p4);
        int[][] c22 = difference(difference(sum(p5, p1), p3), p7);
        return assemble(c11, c12, c21, c22);
    }

    /** Rejects operand pairs whose row counts differ. */
    private static void requireSameSize(int[][] a, int[][] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                    "Matrix sizes differ: " + a.length + " vs " + b.length);
        }
    }

    /** Returns the smallest power of two that is >= n, for n >= 1. */
    private static int nextPowerOfTwo(int n) {
        int highest = Integer.highestOneBit(n);
        return highest == n ? n : highest << 1;
    }

    /** Returns a new size x size matrix holding the overlapping top-left part of m, zero elsewhere. */
    private static int[][] resize(int[][] m, int size) {
        int keep = Math.min(m.length, size);
        int[][] out = new int[size][size];
        for (int i = 0; i < keep; i++) {
            System.arraycopy(m[i], 0, out[i], 0, keep);
        }
        return out;
    }

    /** Returns a fresh half x half copy of the window of m starting at (rowStart, colStart). */
    private static int[][] quadrant(int[][] m, int rowStart, int colStart, int half) {
        int[][] q = new int[half][half];
        copyToMatrixArray(m, rowStart, colStart, half, half, q);
        return q;
    }

    /** Returns a new matrix holding x + y. */
    private static int[][] sum(int[][] x, int[][] y) {
        int[][] out = new int[x.length][x.length];
        squareMatrixAdd(x, y, out);
        return out;
    }

    /** Returns a new matrix holding x - y. */
    private static int[][] difference(int[][] x, int[][] y) {
        int[][] out = new int[x.length][x.length];
        squareMatrixSub(x, y, out);
        return out;
    }

    /** Stitches four equally sized quadrants back into one matrix of twice the size. */
    private static int[][] assemble(int[][] c11, int[][] c12, int[][] c21, int[][] c22) {
        int half = c11.length;
        int[][] c = new int[2 * half][2 * half];
        copyFromMatrixArray(c, 0, 0, half, half, c11);
        copyFromMatrixArray(c, 0, half, half, half, c12);
        copyFromMatrixArray(c, half, 0, half, half, c21);
        copyFromMatrixArray(c, half, half, half, half, c22);
        return c;
    }

    /** Sample 8x8 left operand used by {@link #main(String[])}. */
    public static int sMatrixA[][] = new int[][] {
        {1, 2, 3, 4, 5, 6, 7, 8},
        {1, 2, 3, 4, 5, 6, 7, 8},
        {1, 2, 3, 4, 5, 6, 7, 8},
        {1, 2, 3, 4, 5, 6, 7, 8},
        {1, 2, 3, 4, 5, 6, 7, 8},
        {1, 2, 3, 4, 5, 6, 7, 8},
        {1, 2, 3, 4, 5, 6, 7, 8},
        {1, 2, 3, 4, 5, 6, 7, 8}
    };

    /** Sample 8x8 right operand used by {@link #main(String[])}. */
    public static int sMatrixB[][] = new int[][] {
        {5, 6, 7, 8, 1, 2, 3, 4},
        {5, 6, 7, 8, 1, 2, 3, 4},
        {5, 6, 7, 8, 1, 2, 3, 4},
        {5, 6, 7, 8, 1, 2, 3, 4},
        {5, 6, 7, 8, 1, 2, 3, 4},
        {5, 6, 7, 8, 1, 2, 3, 4},
        {5, 6, 7, 8, 1, 2, 3, 4},
        {5, 6, 7, 8, 1, 2, 3, 4}
    };

    /**
     * Multiplies the two sample matrices with each of the three algorithms and
     * prints every result.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        System.out.println("普通矩阵乘法");
        SquareMatrixMultiply(sMatrixA, sMatrixB);
        System.out.println("\n递归矩阵乘法");
        int C[][] = squareMatrixMultiplyRecursive(sMatrixA, sMatrixB);
        displaySquare(C);
        System.out.println("\n Strassen 递归矩阵乘法");
        C = strassenMatrixMultiplyRecursive(sMatrixA, sMatrixB);
        displaySquare(C);
    }
}
注:凡属于本公众号内容,未经允许不得私自转载,否则将依法追究侵权责任。