CUDA Networks
neural_network_get_predictions.cu
/**
 * @file neural_network_get_predictions.cu
 * @brief Implementation of the NeuralNetwork::get_predictions method.
 */
#include "neural_network.h"

Vector NeuralNetwork::get_predictions() const {
    // Get the argmax of A2 along axis 0 (column-wise)
    return A2.argmax();
}
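To illustrate what a column-wise argmax computes, here is a minimal host-side sketch in plain C++. It is not the project's GPU implementation: the helper name `argmax_columns`, the use of `std::vector<float>`, and the row-major storage layout are all assumptions made for this example, since the listing does not show how the matrix stores its data.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference helper (not part of the project's API).
// For each column of a rows x cols matrix, returns the row index of the
// largest entry in that column.
// Assumption: row-major storage, so element (r, c) is data[r * cols + c].
std::vector<std::size_t> argmax_columns(const std::vector<float>& data,
                                        std::size_t rows, std::size_t cols) {
    assert(rows > 0 && data.size() == rows * cols);
    std::vector<std::size_t> result(cols, 0);  // start with row 0 as the best
    for (std::size_t c = 0; c < cols; ++c) {
        for (std::size_t r = 1; r < rows; ++r) {
            // Strict '>' keeps the first maximum when values tie.
            if (data[r * cols + c] > data[result[c] * cols + c]) {
                result[c] = r;
            }
        }
    }
    return result;
}
```

If each column of the output layer holds the scores for one sample, the returned index for that column is the predicted class. A reference like this can also be useful for checking a GPU result on small inputs.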
Vector argmax() const: Computes the argmax of each column in the matrix.
Vector get_predictions() const: Gets predictions from the output layer (A2).
Vector: Represents a vector with GPU-accelerated operations. Definition: vector.h:13
neural_network.h: Defines the NeuralNetwork class for a simple feedforward neural network.