CUDA Networks
matrix_argmax.cu
Go to the documentation of this file.
1 /**
2  * @file matrix_argmax.cu
3  * @brief GPU implementation of the column-wise argmax function for matrices.
4  */
5 
#include "matrix.h"
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>
8 
/**
 * @brief CUDA kernel for computing the argmax of each column in a matrix.
 *
 * The matrix is read row-major: element (row, col) is m[row * cols + col].
 * Columns are distributed over threads with a grid-stride loop, so any 1D
 * launch configuration (even a single thread) covers every column.
 * Ties resolve to the smallest row index because the comparison is strict ('>').
 *
 * @param m Pointer to the matrix data on the GPU (rows * cols doubles, row-major).
 * @param result Pointer to the result vector on the GPU (cols doubles). Entry
 *        `col` receives the row index of that column's maximum, stored as a double.
 * @param rows Number of rows in the matrix. If rows <= 0 no column has a
 *        maximum and `result` is left untouched.
 * @param cols Number of columns in the matrix.
 */
__global__ void argmax_GPU(const double *m, double *result, int rows, int cols) {
    // An empty column has no maximum; returning here avoids reading m[col] out of bounds.
    if (rows <= 0) {
        return;
    }

    // 64-bit stride/index so neither the stride product nor col += stride can overflow int.
    const long long stride = static_cast<long long>(gridDim.x) * blockDim.x;
    const long long first = static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;

    for (long long col = first; col < cols; col += stride) {
        // Seed the running maximum with row 0 of this column.
        double max_val = m[col];
        int max_idx = 0;

        for (int row = 1; row < rows; ++row) {
            // 64-bit offset: row * cols can exceed INT_MAX for large matrices.
            const double val = m[static_cast<long long>(row) * cols + col];
            if (val > max_val) {
                max_val = val;
                max_idx = row;
            }
        }

        // Store the row index of the column maximum.
        result[col] = static_cast<double>(max_idx);
    }
}
39 
/**
 * @brief Launches the argmax_GPU kernel to perform column-wise argmax on the matrix.
 *
 * For every column, finds the row holding the largest value (the first such
 * row on ties) and stores that row index, as a double, in the returned vector.
 * Blocks until the kernel has finished.
 *
 * @return A Vector of length cols containing the row indices of the maximum
 *         values for each column. An empty Vector is returned when cols == 0.
 * @throws std::invalid_argument if the matrix has columns but no rows
 *         (the argmax of an empty column is undefined).
 * @throws std::runtime_error if the kernel launch or its execution fails.
 */
Vector Matrix::argmax() const {
    // Without rows there is no element to pick; the kernel would otherwise read out of bounds.
    if (rows <= 0 && cols > 0) {
        throw std::invalid_argument("Matrix::argmax: matrix has no rows");
    }

    // Result vector; its data pointer is handed to the kernel, so it lives on the device.
    Vector result(cols);

    // A zero-block grid is an invalid launch configuration, and there is nothing to compute.
    if (cols <= 0) {
        return result;
    }

    // One thread per column; ceil-div written so it cannot overflow near INT_MAX.
    const int threadsPerBlock = 256;
    const int blocksPerGrid = cols / threadsPerBlock + (cols % threadsPerBlock != 0 ? 1 : 0);

    argmax_GPU<<<blocksPerGrid, threadsPerBlock>>>(d_data, result.get_data(), rows, cols);

    // Launch-configuration errors are reported immediately...
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(std::string("argmax_GPU launch failed: ") + cudaGetErrorString(err));
    }

    // ...while faults inside the kernel only surface at the next synchronizing call.
    err = cudaDeviceSynchronize();
    if (err != cudaSuccess) {
        throw std::runtime_error(std::string("argmax_GPU execution failed: ") + cudaGetErrorString(err));
    }

    return result;
}
Vector argmax() const
Computes the argmax of each column in the matrix.
Represents a vector with GPU-accelerated operations.
Definition: vector.h:13
double * get_data() const
Get the raw data pointer of the vector.
Defines the Matrix class for GPU-accelerated matrix operations.
__global__ void argmax_GPU(const double *m, double *result, int rows, int cols)
CUDA kernel for computing the argmax of each column in a matrix.