From 1c47da5ebfc8505b08805f21f7efae75a2bfb32a Mon Sep 17 00:00:00 2001 From: Vulcan <93451215+trholding@users.noreply.github.com> Date: Sat, 20 Jul 2024 19:47:46 +0530 Subject: [PATCH] Update runq.c runq - speed up rmsnorm with OpenMP / OpenACC --- runq.c | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/runq.c b/runq.c index 1a32da0..d3cfc3a 100644 --- a/runq.c +++ b/runq.c @@ -135,9 +135,11 @@ __static_yoink("zipos"); #ifdef OPENMP #define ACCELS() MK_PRAGMA(omp parallel for) #define ACCEL(...) MK_PRAGMA(omp parallel for private(__VA_ARGS__)) +#define ACCELRD(VAR) MK_PRAGMA(omp parallel for reduction(+:VAR)) #elif defined(OPENACC) #define ACCELS() MK_PRAGMA(acc parallel loop) #define ACCEL(...) MK_PRAGMA(acc parallel loop private(__VA_ARGS__)) +#define ACCELRD(VAR) MK_PRAGMA(acc parallel loop reduction(+:VAR)) #endif // ---------------------------------------------------------------------------- @@ -504,6 +506,11 @@ void rmsnorm(float* o, float* x, float* weight, int size) { #ifdef BLAS ss = cblas_sdot(size, x, 1.0f, x, 1.0f); #else +// END L2E Addition +// L2E Addition + #ifdef ACCEL + ACCELRD(ss) // OMP/OACC Macro + #endif // END L2E Addition for (int j = 0; j < size; j++) { ss += x[j] * x[j]; @@ -515,6 +522,11 @@ void rmsnorm(float* o, float* x, float* weight, int size) { ss += 1e-5f; ss = 1.0f / sqrtf(ss); // normalize and scale +// L2E Addition + #ifdef ACCEL + ACCELS() // OMP/OACC Macro + #endif +// END L2E Addition for (int j = 0; j < size; j++) { o[j] = weight[j] * (ss * x[j]); }