Update runq.c

runq - speed up rmsnorm with OpenMP / OpenACC
This commit is contained in:
Vulcan 2024-07-20 19:47:46 +05:30
parent 16e223fbca
commit 1c47da5ebf

12
runq.c
View File

@ -135,9 +135,11 @@ __static_yoink("zipos");
#ifdef OPENMP
#define ACCELS() MK_PRAGMA(omp parallel for)
#define ACCEL(...) MK_PRAGMA(omp parallel for private(__VA_ARGS__))
#define ACCELRD(VAR) MK_PRAGMA(omp parallel for reduction(+:VAR))
#elif defined(OPENACC)
#define ACCELS() MK_PRAGMA(acc parallel loop)
#define ACCEL(...) MK_PRAGMA(acc parallel loop private(__VA_ARGS__))
#define ACCELRD(VAR) MK_PRAGMA(acc parallel loop reduction(+:VAR))
#endif
// ----------------------------------------------------------------------------
@ -504,6 +506,11 @@ void rmsnorm(float* o, float* x, float* weight, int size) {
#ifdef BLAS
ss = cblas_sdot(size, x, 1.0f, x, 1.0f);
#else
// END L2E Addition
// L2E Addition
#ifdef ACCEL
ACCELRD(ss) // OMP/OACC Macro
#endif
// END L2E Addition
for (int j = 0; j < size; j++) {
ss += x[j] * x[j];
@ -515,6 +522,11 @@ void rmsnorm(float* o, float* x, float* weight, int size) {
ss += 1e-5f;
ss = 1.0f / sqrtf(ss);
// normalize and scale
// L2E Addition
#ifdef ACCEL
ACCELS() // OMP/OACC Macro
#endif
// END L2E Addition
for (int j = 0; j < size; j++) {
o[j] = weight[j] * (ss * x[j]);
}