from mpi4py import MPI
import numpy as np
import time

# Function to perform matrix multiplication
def matrix_multiply(a, b):
    """Return the product of two 2-D arrays using an explicit triple loop.

    The multiplication is deliberately written element by element (rather
    than delegating to NumPy's own matrix product) and accumulates into a
    ``float32`` output array.

    Args:
        a: Left operand; must expose a two-element ``shape``.
        b: Right operand; must expose a two-element ``shape``.

    Returns:
        A new ``float32`` array of shape ``(a.shape[0], b.shape[1])``.

    Raises:
        ValueError: If the column count of ``a`` differs from the row
            count of ``b``.
    """
    n_rows, inner = a.shape
    b_rows, n_cols = b.shape

    # Guard clause: the shared dimension must agree.
    if inner != b_rows:
        raise ValueError("Cannot perform matrix multiplication. Invalid dimensions.")

    product = np.zeros((n_rows, n_cols), dtype=np.float32)

    # np.ndindex walks the output cells in row-major order, i.e. the same
    # (row, column) visiting order as two nested range() loops.
    for row, col in np.ndindex(n_rows, n_cols):
        for k in range(inner):
            product[row, col] += a[row, k] * b[k, col]

    return product

# Initialize MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()  # number of MPI processes; must not be reused as a loop variable

# Square matrix dimensions to multiply and benchmark
matrix_sizes = [100, 200, 400, 800]


def _parallel_multiply(a, b):
    """Compute ``a @ b`` cooperatively across every rank of ``comm``.

    This is a collective operation: all ranks must call it. Only the
    values passed on rank 0 are used; other ranks may pass ``None``.

    The rows of ``a`` are split into ``size`` contiguous blocks (one per
    rank), ``b`` is broadcast in full, each rank multiplies its row block
    by ``b``, and rank 0 stacks the partial products back in row order.

    Args:
        a: Left operand on rank 0, ignored elsewhere.
        b: Right operand on rank 0, ignored elsewhere.

    Returns:
        The full product on rank 0, ``None`` on every other rank.
    """
    # comm.scatter needs exactly one item per rank, so hand it a list of
    # row blocks. array_split tolerates row counts that do not divide
    # evenly (some blocks may then be one row shorter, or empty).
    row_blocks = np.array_split(a, size, axis=0) if rank == 0 else None
    local_a = comm.scatter(row_blocks, root=0)

    # Every rank needs all of b to form complete output rows.
    b = comm.bcast(b, root=0)

    local_result = matrix_multiply(local_a, b)

    gathered = comm.gather(local_result, root=0)
    if rank != 0:
        return None
    # Row blocks come back in rank order, so stacking restores the layout.
    return np.vstack(gathered)


for n in matrix_sizes:
    # Create matrices on process 0 only; other ranks receive their share
    # through the collectives inside _parallel_multiply.
    if rank == 0:
        a = np.random.rand(n, n).astype(np.float32)
        b = np.random.rand(n, n).astype(np.float32)
    else:
        a = None
        b = None

    # Distributed multiplication; final_result is only populated on rank 0.
    final_result = _parallel_multiply(a, b)

# Measure and compare execution times
sequential_times = []
parallel_times = []

for n in matrix_sizes:
    if rank == 0:
        a = np.random.rand(n, n).astype(np.float32)
        b = np.random.rand(n, n).astype(np.float32)

        # Sequential matrix multiplication (rank 0 only)
        start_time = time.time()
        sequential_result = matrix_multiply(a, b)
        end_time = time.time()
        sequential_times.append(end_time - start_time)
    else:
        a = None
        b = None

    # Parallel matrix multiplication on the same operands. All ranks must
    # reach the collectives, so this part runs outside the rank-0 branch.
    # The barrier keeps ranks that were idle during the sequential run
    # from skewing the measurement.
    comm.Barrier()
    start_time = time.time()
    final_result = _parallel_multiply(a, b)
    end_time = time.time()
    if rank == 0:
        parallel_times.append(end_time - start_time)

# Print the execution times
if rank == 0:
    print("Sequential Execution Times:", sequential_times)
    print("Parallel Execution Times:", parallel_times)

