mirror of
https://github.com/trholding/llama2.c.git
synced 2026-02-06 11:26:53 +00:00
Update runq.c
runq - speed up rmsnorm with OpenMP / OpenACC
This commit is contained in:
parent
16e223fbca
commit
1c47da5ebf
12
runq.c
12
runq.c
@@ -135,9 +135,11 @@ __static_yoink("zipos");
|
||||
// Loop-acceleration pragma macros, selected at compile time.
// Each macro expands through MK_PRAGMA, which is defined outside this
// excerpt (presumably a _Pragma wrapper -- TODO confirm), and is meant to be
// placed directly before the `for` loop it annotates.
//
//   ACCELS()      - pragma with no extra clauses
//   ACCEL(...)    - pragma with a private(...) clause listing the arguments
//   ACCELRD(VAR)  - pragma with a reduction(+:VAR) clause (sum into VAR)
//
// When neither OPENMP nor OPENACC is defined, none of these macros are
// defined, so every use must be guarded (e.g. with `#ifdef ACCEL`).
#ifdef OPENMP
|
||||
// OpenMP variants: `omp parallel for`.
#define ACCELS() MK_PRAGMA(omp parallel for)
|
||||
#define ACCEL(...) MK_PRAGMA(omp parallel for private(__VA_ARGS__))
|
||||
#define ACCELRD(VAR) MK_PRAGMA(omp parallel for reduction(+:VAR))
|
||||
#elif defined(OPENACC)
|
||||
// OpenACC variants: `acc parallel loop`, with the same clause per macro.
#define ACCELS() MK_PRAGMA(acc parallel loop)
|
||||
#define ACCEL(...) MK_PRAGMA(acc parallel loop private(__VA_ARGS__))
|
||||
#define ACCELRD(VAR) MK_PRAGMA(acc parallel loop reduction(+:VAR))
|
||||
#endif
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
@@ -504,6 +506,11 @@ void rmsnorm(float* o, float* x, float* weight, int size) {
|
||||
#ifdef BLAS
|
||||
ss = cblas_sdot(size, x, 1.0f, x, 1.0f);
|
||||
#else
|
||||
// END L2E Addition
|
||||
// L2E Addition
|
||||
#ifdef ACCEL
|
||||
ACCELRD(ss) // OMP/OACC Macro
|
||||
#endif
|
||||
// END L2E Addition
|
||||
for (int j = 0; j < size; j++) {
|
||||
ss += x[j] * x[j];
|
||||
@@ -515,6 +522,11 @@ void rmsnorm(float* o, float* x, float* weight, int size) {
|
||||
ss += 1e-5f;
|
||||
ss = 1.0f / sqrtf(ss);
|
||||
// normalize and scale
|
||||
// L2E Addition
|
||||
#ifdef ACCEL
|
||||
ACCELS() // OMP/OACC Macro
|
||||
#endif
|
||||
// END L2E Addition
|
||||
for (int j = 0; j < size; j++) {
|
||||
o[j] = weight[j] * (ss * x[j]);
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user