Matrices can be quite intimidating, right? Well, not anymore. In this article we will discuss Strassen's algorithm, a very popular method of matrix multiplication that is asymptotically more efficient than naïve multiplication.
We all know how tedious it is to multiply two matrices by hand. Even a program takes a while to do it, and writing the code can be worse. But do not worry: here we will walk through the whole process.
Introduction
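But first, let us recall what matrix multiplication looks like: for two n×n matrices A and B, the entry C[i][j] of the product C is the sum of A[i][k]*B[k][j] over all k from 0 to n-1.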
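To make this concrete, here is the entry formula together with a small worked example. The matrices and numbers are chosen purely for illustration.

```latex
C = AB, \qquad c_{ij} = \sum_{k=1}^{n} a_{ik}\, b_{kj}

\begin{pmatrix} 1 & 2 \\ 3 & 4 \end{pmatrix}
\begin{pmatrix} 5 & 6 \\ 7 & 8 \end{pmatrix}
=
\begin{pmatrix}
1\cdot 5 + 2\cdot 7 & 1\cdot 6 + 2\cdot 8 \\
3\cdot 5 + 4\cdot 7 & 3\cdot 6 + 4\cdot 8
\end{pmatrix}
=
\begin{pmatrix} 19 & 22 \\ 43 & 50 \end{pmatrix}
```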
Naïve Approach For Multiplying Matrices
The most naïve way to do this is to use three nested loops and calculate the value of each cell in the answer matrix.
Code Implementation
CPP
#include <bits/stdc++.h>
using namespace std;
void matrixmultiply(int **A, int **B, int **C,int n)
{
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
C[i][j] = 0;
for (int k = 0; k < n; k++)
{
C[i][j] += A[i][k]*B[k][j];
}
}
}
}
int main() {
// your code goes here
int n;
cin>>n;
int **A = new int*[n];
for(int i = 0;i<n;i++)
{
A[i] = new int[n];
for(int j =0;j<n;j++)
cin>>A[i][j];
}
int **B = new int*[n];
for(int i = 0;i<n;i++)
{
B[i] = new int[n];
for(int j =0;j<n;j++)
cin>>B[i][j];
}
int **C = new int*[n];
for(int i = 0;i<n;i++)
C[i] = new int[n];
matrixmultiply(A,B,C,n);
for(int i =0;i<n;i++){
cout<<endl;
for(int j = 0;j<n;j++)
cout<<C[i][j]<<" ";}
return 0;
}
JAVA
import java.util.*;
import java.lang.*;
import java.io.*;
class Main
{
public static void matrixmultiply(int[][] A, int[][] B, int[][] C,int n)
{
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
C[i][j] = 0;
for (int k = 0; k < n; k++)
{
C[i][j] += A[i][k]*B[k][j];
}
}
}
}
public static void main (String[] args) throws java.lang.Exception
{
// your code goes here
int n;
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
int[][] A = new int[n][n];
for(int i =0;i<n;i++)
for(int j = 0;j<n;j++)
A[i][j] = sc.nextInt();
int[][] B = new int[n][n];
for(int i =0;i<n;i++)
for(int j = 0;j<n;j++)
B[i][j] = sc.nextInt();
int[][] C = new int[n][n];
matrixmultiply(A,B,C,n);
for(int i =0;i<n;i++){
System.out.println();
for(int j = 0;j<n;j++)
System.out.print(C[i][j]+" ");} // print, not println, so each row stays on one line
}
}
Time Complexity: O(n^3), where the matrices are of size n×n.
Space Complexity: O(n^2) for the C matrix.
As you can see, this method takes cubic time, so another method was developed to reduce the complexity of matrix multiplication.
This method is called Strassen’s Matrix Multiplication. Let’s see this in detail.
Recursive Matrix Multiplication – Strassen has modified this a bit 🙂
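The idea behind this method is that the product of two 2×2 matrices can be computed in constant time. Essentially, take A = [[a11, a12], [a21, a22]] and B = [[b11, b12], [b21, b22]].
Multiplying these two matrices and putting the result in C gives:
c11 = a11*b11 + a12*b21
c12 = a11*b12 + a12*b22
c21 = a21*b11 + a22*b21
c22 = a21*b12 + a22*b22
These four operations always take constant time. This is a divide-and-conquer method, which means a bigger n×n matrix is broken into four n/2 × n/2 quadrants A11, A12, A21 and A22 (and likewise B11, B12, B21 and B22 for B), so that it looks like a 2×2 matrix whose entries are themselves matrices.
Notice that we again have 2×2 matrices, only now each entry is a subproblem. Each quadrant can be treated as a subproblem, and hence we have reduced the bigger problem into smaller ones. Now let's see how we can actually multiply them.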
Algorithm For Matrices Multiplication
Multiply(A, B, n) {
    if (n <= 2) {
        return the product computed with the four formulas shown above
    }
    else { // each Aij and Bij is one of the quadrants described above
        // '+' below is matrix addition, not normal integer addition
        C11 = Multiply(A11, B11, n/2) + Multiply(A12, B21, n/2);
        C12 = Multiply(A11, B12, n/2) + Multiply(A12, B22, n/2);
        C21 = Multiply(A21, B11, n/2) + Multiply(A22, B21, n/2);
        C22 = Multiply(A21, B12, n/2) + Multiply(A22, B22, n/2);
        return C assembled from the quadrants C11, C12, C21, C22
    }
}
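The pseudocode above can be turned into a small runnable sketch. This is only an illustration under a few assumptions: the names Matrix, block, add and multiplyRecursive are made up for this example, the matrices are square with a power-of-two size, and the recursion bottoms out at 1×1 instead of the n <= 2 base case to keep the code short.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<long long>>;

// Copy the h x h block of M whose top-left corner is at (r, c).
Matrix block(const Matrix& M, std::size_t r, std::size_t c, std::size_t h) {
    Matrix out(h, std::vector<long long>(h));
    for (std::size_t i = 0; i < h; i++)
        for (std::size_t j = 0; j < h; j++)
            out[i][j] = M[r + i][c + j];
    return out;
}

// Element-wise sum of two square matrices of the same size.
Matrix add(const Matrix& X, const Matrix& Y) {
    std::size_t n = X.size();
    Matrix out(n, std::vector<long long>(n));
    for (std::size_t i = 0; i < n; i++)
        for (std::size_t j = 0; j < n; j++)
            out[i][j] = X[i][j] + Y[i][j];
    return out;
}

// Eight-call divide and conquer (assumption: size is a power of two).
Matrix multiplyRecursive(const Matrix& A, const Matrix& B) {
    std::size_t n = A.size();
    assert(n > 0 && (n & (n - 1)) == 0 && B.size() == n);
    if (n == 1)  // base case: a single scalar product
        return Matrix(1, std::vector<long long>(1, A[0][0] * B[0][0]));

    std::size_t h = n / 2;
    Matrix A11 = block(A, 0, 0, h), A12 = block(A, 0, h, h);
    Matrix A21 = block(A, h, 0, h), A22 = block(A, h, h, h);
    Matrix B11 = block(B, 0, 0, h), B12 = block(B, 0, h, h);
    Matrix B21 = block(B, h, 0, h), B22 = block(B, h, h, h);

    // Two recursive products per quadrant, eight in total.
    Matrix C11 = add(multiplyRecursive(A11, B11), multiplyRecursive(A12, B21));
    Matrix C12 = add(multiplyRecursive(A11, B12), multiplyRecursive(A12, B22));
    Matrix C21 = add(multiplyRecursive(A21, B11), multiplyRecursive(A22, B21));
    Matrix C22 = add(multiplyRecursive(A21, B12), multiplyRecursive(A22, B22));

    // Stitch the four quadrants back into one matrix.
    Matrix C(n, std::vector<long long>(n));
    for (std::size_t i = 0; i < h; i++) {
        for (std::size_t j = 0; j < h; j++) {
            C[i][j] = C11[i][j];
            C[i][j + h] = C12[i][j];
            C[i + h][j] = C21[i][j];
            C[i + h][j + h] = C22[i][j];
        }
    }
    return C;
}
```

Calling multiplyRecursive on two 2×2 or 4×4 matrices should give the same result as the three-loop version. Copying every quadrant keeps the sketch readable but is wasteful; a real implementation would more likely work with index offsets into the original matrices.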
Strassen’s Matrix Multiplication
The above recursive function also takes O(n^3) time, because it makes eight calls to multiply smaller matrices. The overall time complexity of the multiplication depends on these eight calls. This is where Strassen modified the algorithm: he reduced the eight calls to seven by formulating a different set of formulas for the same result.
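To see why eight half-size calls still give cubic time, the recursion can be written as a recurrence. This is a sketch assuming n is a power of two and that splitting and combining the quadrants costs on the order of n^2 operations. Since log2(8) = 3 is larger than 2, the Master Theorem says the recursive calls dominate:

```latex
T(n) = 8\,T\!\left(\tfrac{n}{2}\right) + \Theta(n^2)
\quad\Longrightarrow\quad
T(n) = \Theta\!\left(n^{\log_2 8}\right) = \Theta(n^3)
```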
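These are the 7 equations that replace the previous 4 equations:
P1 = (A11 + A22) * (B11 + B22)
P2 = (A21 + A22) * B11
P3 = A11 * (B12 - B22)
P4 = A22 * (B21 - B11)
P5 = (A11 + A12) * B22
P6 = (A21 - A11) * (B11 + B12)
P7 = (A12 - A22) * (B21 + B22)
The quadrants of C are then obtained using only additions and subtractions:
C11 = P1 + P4 - P5 + P7
C12 = P3 + P5
C21 = P2 + P4
C22 = P1 - P2 + P3 + P6
Hence the number of recursive multiplications drops from 8 to 7.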
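As a quick sanity check of the seven-product idea, here is a minimal sketch that applies the same products p1..p7 used in the full implementation to plain 2×2 matrices of numbers and compares the result with the direct formulas. The names Mat2, strassen2x2, direct2x2 and checkAgainstDirect are hypothetical and exist only in this example.

```cpp
#include <array>
#include <cassert>

using Mat2 = std::array<std::array<long long, 2>, 2>;

// 2x2 product using Strassen's seven multiplications.
Mat2 strassen2x2(const Mat2& A, const Mat2& B) {
    long long p1 = (A[0][0] + A[1][1]) * (B[0][0] + B[1][1]);
    long long p2 = (A[1][0] + A[1][1]) * B[0][0];
    long long p3 = A[0][0] * (B[0][1] - B[1][1]);
    long long p4 = A[1][1] * (B[1][0] - B[0][0]);
    long long p5 = (A[0][0] + A[0][1]) * B[1][1];
    long long p6 = (A[1][0] - A[0][0]) * (B[0][0] + B[0][1]);
    long long p7 = (A[0][1] - A[1][1]) * (B[1][0] + B[1][1]);

    Mat2 C{};
    C[0][0] = p1 + p4 - p5 + p7;
    C[0][1] = p3 + p5;
    C[1][0] = p2 + p4;
    C[1][1] = p1 - p2 + p3 + p6;
    return C;
}

// 2x2 product using the usual eight multiplications.
Mat2 direct2x2(const Mat2& A, const Mat2& B) {
    Mat2 C{};
    for (int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            C[i][j] = A[i][0] * B[0][j] + A[i][1] * B[1][j];
    return C;
}

// Aborts (in a debug build) if the two methods disagree on the given input.
void checkAgainstDirect(const Mat2& A, const Mat2& B) {
    assert(strassen2x2(A, B) == direct2x2(A, B));
}
```

For A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]] both functions give [[19, 22], [43, 50]]. In the real algorithm the entries are sub-matrices rather than numbers, so the order of the factors in each product matters and must be kept as written.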
Code Implementation:
CPP
#include<bits/stdc++.h>
// Leaf size: sub-matrices of this size or smaller are multiplied with the plain cubic method.
int size;
using namespace std;
void strassen(vector< vector<int> > &A,
vector< vector<int> > &B,
vector< vector<int> > &C, unsigned int t);
unsigned int nextPowerOfTwo(int n);
void strassenR(vector< vector<int> > &A,
vector< vector<int> > &B,
vector< vector<int> > &C,
int t);
void sum(vector< vector<int> > &A,
vector< vector<int> > &B,
vector< vector<int> > &C, int t);
void subtract(vector< vector<int> > &A,
vector< vector<int> > &B,
vector< vector<int> > &C, int t);
void printMatrix(vector< vector<int> > matrix, int n);
void read(string filename, vector< vector<int> > &A, vector< vector<int> > &B);
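// Plain cubic multiplication used at the leaves of the recursion.
// Accumulates into C, so C must be zero-initialized by the caller.
void multiply(const vector< vector<int> > &A,
const vector< vector<int> > &B,
vector< vector<int> > &C, int n) {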
for (int i = 0; i < n; i++) {
for (int k = 0; k < n; k++) {
for (int j = 0; j < n; j++) {
C[i][j] += A[i][k] * B[k][j];
}
}
}
}
void strassenR(vector< vector<int> > &A,
vector< vector<int> > &B,
vector< vector<int> > &C, int t) {
if (t <= ::size) { // ::size is the global leaf size; a bare "size" can clash with std::size in C++17
multiply(A, B, C, t);
return;
}
else {
int newT = t/2;
vector<int> inner (newT);
vector< vector<int> >
a11(newT,inner), a12(newT,inner), a21(newT,inner), a22(newT,inner),
b11(newT,inner), b12(newT,inner), b21(newT,inner), b22(newT,inner),
c11(newT,inner), c12(newT,inner), c21(newT,inner), c22(newT,inner),
p1(newT,inner), p2(newT,inner), p3(newT,inner), p4(newT,inner),
p5(newT,inner), p6(newT,inner), p7(newT,inner),
aResult(newT,inner), bResult(newT,inner);
int i, j;
for (i = 0; i < newT; i++) {
for (j = 0; j < newT; j++) {
a11[i][j] = A[i][j];
a12[i][j] = A[i][j + newT];
a21[i][j] = A[i + newT][j];
a22[i][j] = A[i + newT][j + newT];
b11[i][j] = B[i][j];
b12[i][j] = B[i][j + newT];
b21[i][j] = B[i + newT][j];
b22[i][j] = B[i + newT][j + newT];
}
}
sum(a11, a22, aResult, newT); // a11 + a22
sum(b11, b22, bResult, newT); // b11 + b22
strassenR(aResult, bResult, p1, newT); // p1 = (a11+a22) * (b11+b22)
sum(a21, a22, aResult, newT); // a21 + a22
strassenR(aResult, b11, p2, newT); // p2 = (a21+a22) * (b11)
subtract(b12, b22, bResult, newT); // b12 - b22
strassenR(a11, bResult, p3, newT); // p3 = (a11) * (b12 - b22)
subtract(b21, b11, bResult, newT); // b21 - b11
strassenR(a22, bResult, p4, newT); // p4 = (a22) * (b21 - b11)
sum(a11, a12, aResult, newT); // a11 + a12
strassenR(aResult, b22, p5, newT); // p5 = (a11+a12) * (b22)
subtract(a21, a11, aResult, newT); // a21 - a11
sum(b11, b12, bResult, newT); // b11 + b12
strassenR(aResult, bResult, p6, newT); // p6 = (a21-a11) * (b11+b12)
subtract(a12, a22, aResult, newT); // a12 - a22
sum(b21, b22, bResult, newT); // b21 + b22
strassenR(aResult, bResult, p7, newT); // p7 = (a12-a22) * (b21+b22)
// calculating c12, c21, c11 and c22:
sum(p3, p5, c12, newT); // c12 = p3 + p5
sum(p2, p4, c21, newT); // c21 = p2 + p4
sum(p1, p4, aResult, newT); // p1 + p4
sum(aResult, p7, bResult, newT); // p1 + p4 + p7
subtract(bResult, p5, c11, newT); // c11 = p1 + p4 - p5 + p7
sum(p1, p3, aResult, newT); // p1 + p3
sum(aResult, p6, bResult, newT); // p1 + p3 + p6
subtract(bResult, p2, c22, newT); // c22 = p1 + p3 - p2 + p6
for (i = 0; i < newT ; i++) {
for (j = 0 ; j < newT ; j++) {
C[i][j] = c11[i][j];
C[i][j + newT] = c12[i][j];
C[i + newT][j] = c21[i][j];
C[i + newT][j + newT] = c22[i][j];
}
}
}
}
unsigned int nextPowerOfTwo(int n) {
return pow(2, int(ceil(log2(n))));
}
void strassen(vector< vector<int> > &A,
vector< vector<int> > &B,
vector< vector<int> > &C, unsigned int n) {
//unsigned int n = t;
unsigned int m = nextPowerOfTwo(n);
vector<int> inner(m);
vector< vector<int> > APrep(m, inner), BPrep(m, inner), CPrep(m, inner);
for(unsigned int i=0; i<n; i++) {
for (unsigned int j=0; j<n; j++) {
APrep[i][j] = A[i][j];
BPrep[i][j] = B[i][j];
}
}
strassenR(APrep, BPrep, CPrep, m);
for(unsigned int i=0; i<n; i++) {
for (unsigned int j=0; j<n; j++) {
C[i][j] = CPrep[i][j];
}
}
}
void sum(vector< vector<int> > &A,
vector< vector<int> > &B,
vector< vector<int> > &C, int t) {
int i, j;
for (i = 0; i < t; i++) {
for (j = 0; j < t; j++) {
C[i][j] = A[i][j] + B[i][j];
}
}
}
void subtract(vector< vector<int> > &A,
vector< vector<int> > &B,
vector< vector<int> > &C, int t) {
int i, j;
for (i = 0; i < t; i++) {
for (j = 0; j < t; j++) {
C[i][j] = A[i][j] - B[i][j];
}
}
}
int getMatrixSize(string filename) {
string line;
ifstream infile;
infile.open (filename.c_str());
getline(infile, line);
return count(line.begin(), line.end(), '\t') + 1;
}
void read(string filename, vector< vector<int> > &A, vector< vector<int> > &B) {
string line;
FILE* matrixfile = freopen(filename.c_str(), "r", stdin);
if (matrixfile == 0) {
cerr << "Could not read file " << filename << endl;
return;
}
int i = 0, j, a;
while (getline(cin, line) && !line.empty()) {
istringstream iss(line);
j = 0;
while (iss >> a) {
A[i][j] = a;
j++;
}
i++;
}
i = 0;
while (getline(cin, line)) {
istringstream iss(line);
j = 0;
while (iss >> a) {
B[i][j] = a;
j++;
}
i++;
}
fclose (matrixfile);
}
void printMatrix(vector< vector<int> > matrix, int n) {
for (int i=0; i < n; i++) {
for (int j=0; j < n; j++) {
if (j != 0) {
cout << "\t";
}
cout << matrix[i][j];
}
cout << endl;
}
}
int main (int argc, char* argv[]) {
string filename;
if (argc < 3) {
filename = "2000.in";
} else {
filename = argv[2];
}
if (argc < 5) {
::size = 16; // default leaf size
} else {
::size = atoi(argv[4]);
}
int n = getMatrixSize(filename);
vector<int> inner (n);
vector< vector<int> > A(n, inner), B(n, inner), C(n, inner);
read (filename, A, B);
strassen(A, B, C, n);
printMatrix(C, n);
return 0;
}
T(N) = 7T(N/2) + O(N^2)
From the Master Theorem, the time complexity of the above method is
O(N^(log2 7)), which is approximately O(N^2.8074).
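To get a feel for what replacing eight calls with seven means in practice, the following sketch counts only the scalar multiplications performed when the recursion runs all the way down to 1×1 blocks. The function name scalarMults is hypothetical, n is assumed to be a power of two, and the extra additions that Strassen's method needs are deliberately ignored.

```cpp
#include <cassert>
#include <cstdint>

// Scalar multiplications when an n x n product is split into `calls`
// half-size products per level, down to 1 x 1 blocks.
std::uint64_t scalarMults(std::uint64_t n, std::uint64_t calls) {
    assert(n > 0 && (n & (n - 1)) == 0);  // assumption: n is a power of two
    return n == 1 ? 1 : calls * scalarMults(n / 2, calls);
}
```

With these assumptions, scalarMults(1024, 8) is 1073741824 (which is 1024^3), while scalarMults(1024, 7) is 282475249, roughly a quarter as many. The code above stops recursing at a leaf size (16 by default) and falls back to the cubic method, so its actual counts will differ from this idealized picture.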
Frequently Asked Questions
Well, there are a lot of ways although another interesting one is matrix chain multiplication.
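When we are talking about a large value of n, each level of recursion performs 7 sub-multiplications instead of 8, which is 12.5% fewer, and this saving compounds from level to level as n grows.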
Well, this is applicable as long as the two matrices are valid for multiplication.
This is quite a long article but I hope this is helpful to an aspiring developer or programmer.
By Aniruddha Guin