Branch data Line data Source code
1 : : /* SPDX-License-Identifier: BSD-3-Clause
2 : : * Copyright (c) 2022 Marvell.
3 : : */
4 : :
5 : : #include "mldev_utils_scalar.h"
6 : :
7 : : /* Description:
8 : : * This file implements scalar versions of Machine Learning utility functions used to convert data
9 : : * types from higher precision to lower precision and vice-versa, except bfloat16.
10 : : */
11 : :
12 : : int
13 : 0 : rte_ml_io_float32_to_int8(float scale, uint64_t nb_elements, void *input, void *output)
14 : : {
15 : : float *input_buffer;
16 : : int8_t *output_buffer;
17 : : uint64_t i;
18 : : int i32;
19 : :
20 [ # # # # ]: 0 : if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
21 : : return -EINVAL;
22 : :
23 : : input_buffer = (float *)input;
24 : : output_buffer = (int8_t *)output;
25 : :
26 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
27 : 0 : i32 = (int32_t)round((*input_buffer) * scale);
28 : :
29 : : if (i32 < INT8_MIN)
30 : : i32 = INT8_MIN;
31 : :
32 : : if (i32 > INT8_MAX)
33 : : i32 = INT8_MAX;
34 : :
35 : 0 : *output_buffer = (int8_t)i32;
36 : :
37 : 0 : input_buffer++;
38 : 0 : output_buffer++;
39 : : }
40 : :
41 : : return 0;
42 : : }
43 : :
44 : : int
45 : 0 : rte_ml_io_int8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
46 : : {
47 : : int8_t *input_buffer;
48 : : float *output_buffer;
49 : : uint64_t i;
50 : :
51 [ # # # # ]: 0 : if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
52 : : return -EINVAL;
53 : :
54 : : input_buffer = (int8_t *)input;
55 : : output_buffer = (float *)output;
56 : :
57 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
58 : 0 : *output_buffer = scale * (float)(*input_buffer);
59 : :
60 : 0 : input_buffer++;
61 : 0 : output_buffer++;
62 : : }
63 : :
64 : : return 0;
65 : : }
66 : :
67 : : int
68 : 0 : rte_ml_io_float32_to_uint8(float scale, uint64_t nb_elements, void *input, void *output)
69 : : {
70 : : float *input_buffer;
71 : : uint8_t *output_buffer;
72 : : int32_t i32;
73 : : uint64_t i;
74 : :
75 [ # # # # ]: 0 : if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
76 : : return -EINVAL;
77 : :
78 : : input_buffer = (float *)input;
79 : : output_buffer = (uint8_t *)output;
80 : :
81 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
82 : 0 : i32 = (int32_t)round((*input_buffer) * scale);
83 : :
84 : : if (i32 < 0)
85 : : i32 = 0;
86 : :
87 : : if (i32 > UINT8_MAX)
88 : : i32 = UINT8_MAX;
89 : :
90 : 0 : *output_buffer = (uint8_t)i32;
91 : :
92 : 0 : input_buffer++;
93 : 0 : output_buffer++;
94 : : }
95 : :
96 : : return 0;
97 : : }
98 : :
99 : : int
100 : 0 : rte_ml_io_uint8_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
101 : : {
102 : : uint8_t *input_buffer;
103 : : float *output_buffer;
104 : : uint64_t i;
105 : :
106 [ # # # # ]: 0 : if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
107 : : return -EINVAL;
108 : :
109 : : input_buffer = (uint8_t *)input;
110 : : output_buffer = (float *)output;
111 : :
112 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
113 : 0 : *output_buffer = scale * (float)(*input_buffer);
114 : :
115 : 0 : input_buffer++;
116 : 0 : output_buffer++;
117 : : }
118 : :
119 : : return 0;
120 : : }
121 : :
122 : : int
123 : 0 : rte_ml_io_float32_to_int16(float scale, uint64_t nb_elements, void *input, void *output)
124 : : {
125 : : float *input_buffer;
126 : : int16_t *output_buffer;
127 : : int32_t i32;
128 : : uint64_t i;
129 : :
130 [ # # # # ]: 0 : if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
131 : : return -EINVAL;
132 : :
133 : : input_buffer = (float *)input;
134 : : output_buffer = (int16_t *)output;
135 : :
136 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
137 : 0 : i32 = (int32_t)round((*input_buffer) * scale);
138 : :
139 : : if (i32 < INT16_MIN)
140 : : i32 = INT16_MIN;
141 : :
142 : : if (i32 > INT16_MAX)
143 : : i32 = INT16_MAX;
144 : :
145 : 0 : *output_buffer = (int16_t)i32;
146 : :
147 : 0 : input_buffer++;
148 : 0 : output_buffer++;
149 : : }
150 : :
151 : : return 0;
152 : : }
153 : :
154 : : int
155 : 0 : rte_ml_io_int16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
156 : : {
157 : : int16_t *input_buffer;
158 : : float *output_buffer;
159 : : uint64_t i;
160 : :
161 [ # # # # ]: 0 : if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
162 : : return -EINVAL;
163 : :
164 : : input_buffer = (int16_t *)input;
165 : : output_buffer = (float *)output;
166 : :
167 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
168 : 0 : *output_buffer = scale * (float)(*input_buffer);
169 : :
170 : 0 : input_buffer++;
171 : 0 : output_buffer++;
172 : : }
173 : :
174 : : return 0;
175 : : }
176 : :
177 : : int
178 : 0 : rte_ml_io_float32_to_uint16(float scale, uint64_t nb_elements, void *input, void *output)
179 : : {
180 : : float *input_buffer;
181 : : uint16_t *output_buffer;
182 : : int32_t i32;
183 : : uint64_t i;
184 : :
185 [ # # # # ]: 0 : if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
186 : : return -EINVAL;
187 : :
188 : : input_buffer = (float *)input;
189 : : output_buffer = (uint16_t *)output;
190 : :
191 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
192 : 0 : i32 = (int32_t)round((*input_buffer) * scale);
193 : :
194 : : if (i32 < 0)
195 : : i32 = 0;
196 : :
197 : : if (i32 > UINT16_MAX)
198 : : i32 = UINT16_MAX;
199 : :
200 : 0 : *output_buffer = (uint16_t)i32;
201 : :
202 : 0 : input_buffer++;
203 : 0 : output_buffer++;
204 : : }
205 : :
206 : : return 0;
207 : : }
208 : :
209 : : int
210 : 0 : rte_ml_io_uint16_to_float32(float scale, uint64_t nb_elements, void *input, void *output)
211 : : {
212 : : uint16_t *input_buffer;
213 : : float *output_buffer;
214 : : uint64_t i;
215 : :
216 [ # # # # ]: 0 : if ((scale == 0) || (nb_elements == 0) || (input == NULL) || (output == NULL))
217 : : return -EINVAL;
218 : :
219 : : input_buffer = (uint16_t *)input;
220 : : output_buffer = (float *)output;
221 : :
222 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
223 : 0 : *output_buffer = scale * (float)(*input_buffer);
224 : :
225 : 0 : input_buffer++;
226 : 0 : output_buffer++;
227 : : }
228 : :
229 : : return 0;
230 : : }
231 : :
232 : : /* Convert a single precision floating point number (float32) into a half precision
233 : : * floating point number (float16) using round to nearest rounding mode.
234 : : */
235 : : static uint16_t
236 : 0 : __float32_to_float16_scalar_rtn(float x)
237 : : {
238 : : union float32 f32; /* float32 input */
239 : : uint32_t f32_s; /* float32 sign */
240 : : uint32_t f32_e; /* float32 exponent */
241 : : uint32_t f32_m; /* float32 mantissa */
242 : : uint16_t f16_s; /* float16 sign */
243 : : uint16_t f16_e; /* float16 exponent */
244 : : uint16_t f16_m; /* float16 mantissa */
245 : : uint32_t tbits; /* number of truncated bits */
246 : : uint32_t tmsb; /* MSB position of truncated bits */
247 : : uint32_t m_32; /* temporary float32 mantissa */
248 : : uint16_t m_16; /* temporary float16 mantissa */
249 : : uint16_t u16; /* float16 output */
250 : : int be_16; /* float16 biased exponent, signed */
251 : :
252 : 0 : f32.f = x;
253 : 0 : f32_s = (f32.u & FP32_MASK_S) >> FP32_LSB_S;
254 : 0 : f32_e = (f32.u & FP32_MASK_E) >> FP32_LSB_E;
255 : 0 : f32_m = (f32.u & FP32_MASK_M) >> FP32_LSB_M;
256 : :
257 : : f16_s = f32_s;
258 : : f16_e = 0;
259 : : f16_m = 0;
260 : :
261 [ # # # ]: 0 : switch (f32_e) {
262 : : case (0): /* float32: zero or subnormal number */
263 : : f16_e = 0;
264 : : f16_m = 0; /* convert to zero */
265 : : break;
266 : 0 : case (FP32_MASK_E >> FP32_LSB_E): /* float32: infinity or nan */
267 : : f16_e = FP16_MASK_E >> FP16_LSB_E;
268 [ # # ]: 0 : if (f32_m == 0) { /* infinity */
269 : : f16_m = 0;
270 : : } else { /* nan, propagate mantissa and set MSB of mantissa to 1 */
271 : 0 : f16_m = f32_m >> (FP32_MSB_M - FP16_MSB_M);
272 : 0 : f16_m |= BIT(FP16_MSB_M);
273 : : }
274 : : break;
275 : 0 : default: /* float32: normal number */
276 : : /* compute biased exponent for float16 */
277 : 0 : be_16 = (int)f32_e - FP32_BIAS_E + FP16_BIAS_E;
278 : :
279 : : /* overflow, be_16 = [31-INF], set to infinity */
280 [ # # ]: 0 : if (be_16 >= (int)(FP16_MASK_E >> FP16_LSB_E)) {
281 : : f16_e = FP16_MASK_E >> FP16_LSB_E;
282 : : f16_m = 0;
283 [ # # ]: 0 : } else if ((be_16 >= 1) && (be_16 < (int)(FP16_MASK_E >> FP16_LSB_E))) {
284 : : /* normal float16, be_16 = [1:30]*/
285 : 0 : f16_e = be_16;
286 : 0 : m_16 = f32_m >> (FP32_LSB_E - FP16_LSB_E);
287 : : tmsb = FP32_MSB_M - FP16_MSB_M - 1;
288 [ # # ]: 0 : if ((f32_m & GENMASK_U32(tmsb, 0)) > BIT(tmsb)) {
289 : : /* round: non-zero truncated bits except MSB */
290 : 0 : m_16++;
291 : :
292 : : /* overflow into exponent */
293 [ # # ]: 0 : if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
294 : 0 : f16_e++;
295 [ # # ]: 0 : } else if ((f32_m & GENMASK_U32(tmsb, 0)) == BIT(tmsb)) {
296 : : /* round: MSB of truncated bits and LSB of m_16 is set */
297 [ # # ]: 0 : if ((m_16 & 0x1) == 0x1) {
298 : 0 : m_16++;
299 : :
300 : : /* overflow into exponent */
301 [ # # ]: 0 : if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
302 : 0 : f16_e++;
303 : : }
304 : : }
305 : 0 : f16_m = m_16 & FP16_MASK_M;
306 [ # # ]: 0 : } else if ((be_16 >= -(int)(FP16_MSB_M)) && (be_16 < 1)) {
307 : : /* underflow: zero / subnormal, be_16 = [-9:0] */
308 : : f16_e = 0;
309 : :
310 : : /* add implicit leading zero */
311 : 0 : m_32 = f32_m | BIT(FP32_LSB_E);
312 : 0 : tbits = FP32_LSB_E - FP16_LSB_E - be_16 + 1;
313 : 0 : m_16 = m_32 >> tbits;
314 : :
315 : : /* if non-leading truncated bits are set */
316 [ # # ]: 0 : if ((f32_m & GENMASK_U32(tbits - 1, 0)) > BIT(tbits - 1)) {
317 : 0 : m_16++;
318 : :
319 : : /* overflow into exponent */
320 [ # # ]: 0 : if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
321 : : f16_e++;
322 [ # # ]: 0 : } else if ((f32_m & GENMASK_U32(tbits - 1, 0)) == BIT(tbits - 1)) {
323 : : /* if leading truncated bit is set */
324 [ # # ]: 0 : if ((m_16 & 0x1) == 0x1) {
325 : 0 : m_16++;
326 : :
327 : : /* overflow into exponent */
328 [ # # ]: 0 : if (((m_16 & FP16_MASK_E) >> FP16_LSB_E) == 0x1)
329 : : f16_e++;
330 : : }
331 : : }
332 : 0 : f16_m = m_16 & FP16_MASK_M;
333 [ # # ]: 0 : } else if (be_16 == -(int)(FP16_MSB_M + 1)) {
334 : : /* underflow: zero, be_16 = [-10] */
335 : : f16_e = 0;
336 [ # # ]: 0 : if (f32_m != 0)
337 : : f16_m = 1;
338 : : else
339 : : f16_m = 0;
340 : : } else {
341 : : /* underflow: zero, be_16 = [-INF:-11] */
342 : : f16_e = 0;
343 : : f16_m = 0;
344 : : }
345 : :
346 : : break;
347 : : }
348 : :
349 : 0 : u16 = FP16_PACK(f16_s, f16_e, f16_m);
350 : :
351 : 0 : return u16;
352 : : }
353 : :
354 : : int
355 : 0 : rte_ml_io_float32_to_float16(uint64_t nb_elements, void *input, void *output)
356 : : {
357 : : float *input_buffer;
358 : : uint16_t *output_buffer;
359 : : uint64_t i;
360 : :
361 [ # # # # ]: 0 : if ((nb_elements == 0) || (input == NULL) || (output == NULL))
362 : : return -EINVAL;
363 : :
364 : : input_buffer = (float *)input;
365 : : output_buffer = (uint16_t *)output;
366 : :
367 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
368 : 0 : *output_buffer = __float32_to_float16_scalar_rtn(*input_buffer);
369 : :
370 : 0 : input_buffer = input_buffer + 1;
371 : 0 : output_buffer = output_buffer + 1;
372 : : }
373 : :
374 : : return 0;
375 : : }
376 : :
377 : : /* Convert a half precision floating point number (float16) into a single precision
378 : : * floating point number (float32).
379 : : */
380 : : static float
381 : 0 : __float16_to_float32_scalar_rtx(uint16_t f16)
382 : : {
383 : : union float32 f32; /* float32 output */
384 : : uint16_t f16_s; /* float16 sign */
385 : : uint16_t f16_e; /* float16 exponent */
386 : : uint16_t f16_m; /* float16 mantissa */
387 : : uint32_t f32_s; /* float32 sign */
388 : : uint32_t f32_e; /* float32 exponent */
389 : : uint32_t f32_m; /* float32 mantissa*/
390 : : uint8_t shift; /* number of bits to be shifted */
391 : : uint32_t clz; /* count of leading zeroes */
392 : : int e_16; /* float16 exponent unbiased */
393 : :
394 : 0 : f16_s = (f16 & FP16_MASK_S) >> FP16_LSB_S;
395 : 0 : f16_e = (f16 & FP16_MASK_E) >> FP16_LSB_E;
396 : 0 : f16_m = (f16 & FP16_MASK_M) >> FP16_LSB_M;
397 : :
398 : 0 : f32_s = f16_s;
399 [ # # # ]: 0 : switch (f16_e) {
400 : 0 : case (FP16_MASK_E >> FP16_LSB_E): /* float16: infinity or nan */
401 : : f32_e = FP32_MASK_E >> FP32_LSB_E;
402 [ # # ]: 0 : if (f16_m == 0x0) { /* infinity */
403 : : f32_m = f16_m;
404 : : } else { /* nan, propagate mantissa, set MSB of mantissa to 1 */
405 : 0 : f32_m = f16_m;
406 : : shift = FP32_MSB_M - FP16_MSB_M;
407 : 0 : f32_m = (f32_m << shift) & FP32_MASK_M;
408 : 0 : f32_m |= BIT(FP32_MSB_M);
409 : : }
410 : : break;
411 : 0 : case 0: /* float16: zero or sub-normal */
412 : 0 : f32_m = f16_m;
413 [ # # ]: 0 : if (f16_m == 0) { /* zero signed */
414 : : f32_e = 0;
415 : : } else { /* subnormal numbers */
416 : 0 : clz = rte_clz32((uint32_t)f16_m) - sizeof(uint32_t) * 8 + FP16_LSB_E;
417 : 0 : e_16 = (int)f16_e - clz;
418 : 0 : f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;
419 : :
420 : 0 : shift = clz + (FP32_MSB_M - FP16_MSB_M) + 1;
421 : 0 : f32_m = (f32_m << shift) & FP32_MASK_M;
422 : : }
423 : : break;
424 : 0 : default: /* normal numbers */
425 : 0 : f32_m = f16_m;
426 : : e_16 = (int)f16_e;
427 : 0 : f32_e = FP32_BIAS_E + e_16 - FP16_BIAS_E;
428 : :
429 : : shift = (FP32_MSB_M - FP16_MSB_M);
430 : 0 : f32_m = (f32_m << shift) & FP32_MASK_M;
431 : : }
432 : :
433 : 0 : f32.u = FP32_PACK(f32_s, f32_e, f32_m);
434 : :
435 : 0 : return f32.f;
436 : : }
437 : :
438 : : int
439 : 0 : rte_ml_io_float16_to_float32(uint64_t nb_elements, void *input, void *output)
440 : : {
441 : : uint16_t *input_buffer;
442 : : float *output_buffer;
443 : : uint64_t i;
444 : :
445 [ # # # # ]: 0 : if ((nb_elements == 0) || (input == NULL) || (output == NULL))
446 : : return -EINVAL;
447 : :
448 : : input_buffer = (uint16_t *)input;
449 : : output_buffer = (float *)output;
450 : :
451 [ # # ]: 0 : for (i = 0; i < nb_elements; i++) {
452 : 0 : *output_buffer = __float16_to_float32_scalar_rtx(*input_buffer);
453 : :
454 : 0 : input_buffer = input_buffer + 1;
455 : 0 : output_buffer = output_buffer + 1;
456 : : }
457 : :
458 : : return 0;
459 : : }
|