Branch data Line data Source code
1 : : /* SPDX-License-Identifier: BSD-3-Clause
2 : : * Copyright (c) 2023 Marvell.
3 : : */
4 : :
5 : : #include <rte_mldev.h>
6 : :
7 : : #include <mldev_utils.h>
8 : :
9 : : #include <roc_api.h>
10 : :
11 : : #include "cnxk_ml_io.h"
12 : :
13 : : inline int
14 : 0 : cnxk_ml_io_quantize_single(struct cnxk_ml_io *input, uint8_t *dbuffer, uint8_t *qbuffer)
15 : : {
16 : : enum rte_ml_io_type qtype;
17 : : enum rte_ml_io_type dtype;
18 : : uint32_t nb_elements;
19 : : float qscale;
20 : : int ret = 0;
21 : :
22 : 0 : dtype = input->dtype;
23 : 0 : qtype = input->qtype;
24 : 0 : qscale = input->scale;
25 : 0 : nb_elements = input->nb_elements;
26 : :
27 [ # # ]: 0 : if (dtype == qtype) {
28 [ # # ]: 0 : rte_memcpy(qbuffer, dbuffer, input->sz_d);
29 : : } else {
30 [ # # # # : 0 : switch (qtype) {
# # ]
31 : 0 : case RTE_ML_IO_TYPE_INT8:
32 : 0 : ret = rte_ml_io_float32_to_int8(qscale, nb_elements, dbuffer, qbuffer);
33 : 0 : break;
34 : 0 : case RTE_ML_IO_TYPE_UINT8:
35 : 0 : ret = rte_ml_io_float32_to_uint8(qscale, nb_elements, dbuffer, qbuffer);
36 : 0 : break;
37 : 0 : case RTE_ML_IO_TYPE_INT16:
38 : 0 : ret = rte_ml_io_float32_to_int16(qscale, nb_elements, dbuffer, qbuffer);
39 : 0 : break;
40 : 0 : case RTE_ML_IO_TYPE_UINT16:
41 : 0 : ret = rte_ml_io_float32_to_uint16(qscale, nb_elements, dbuffer, qbuffer);
42 : 0 : break;
43 : 0 : case RTE_ML_IO_TYPE_FP16:
44 : 0 : ret = rte_ml_io_float32_to_float16(nb_elements, dbuffer, qbuffer);
45 : 0 : break;
46 : 0 : default:
47 : 0 : plt_err("Unsupported qtype : %u", qtype);
48 : : ret = -ENOTSUP;
49 : : }
50 : : }
51 : :
52 : 0 : return ret;
53 : : }
54 : :
55 : : inline int
56 : 0 : cnxk_ml_io_dequantize_single(struct cnxk_ml_io *output, uint8_t *qbuffer, uint8_t *dbuffer)
57 : : {
58 : : enum rte_ml_io_type qtype;
59 : : enum rte_ml_io_type dtype;
60 : : uint32_t nb_elements;
61 : : float dscale;
62 : : int ret = 0;
63 : :
64 : 0 : dtype = output->dtype;
65 : 0 : qtype = output->qtype;
66 : 0 : dscale = output->scale;
67 : 0 : nb_elements = output->nb_elements;
68 : :
69 [ # # ]: 0 : if (dtype == qtype) {
70 [ # # ]: 0 : rte_memcpy(dbuffer, qbuffer, output->sz_q);
71 : : } else {
72 [ # # # # : 0 : switch (qtype) {
# # ]
73 : 0 : case RTE_ML_IO_TYPE_INT8:
74 : 0 : ret = rte_ml_io_int8_to_float32(dscale, nb_elements, qbuffer, dbuffer);
75 : 0 : break;
76 : 0 : case RTE_ML_IO_TYPE_UINT8:
77 : 0 : ret = rte_ml_io_uint8_to_float32(dscale, nb_elements, qbuffer, dbuffer);
78 : 0 : break;
79 : 0 : case RTE_ML_IO_TYPE_INT16:
80 : 0 : ret = rte_ml_io_int16_to_float32(dscale, nb_elements, qbuffer, dbuffer);
81 : 0 : break;
82 : 0 : case RTE_ML_IO_TYPE_UINT16:
83 : 0 : ret = rte_ml_io_uint16_to_float32(dscale, nb_elements, qbuffer, dbuffer);
84 : 0 : break;
85 : 0 : case RTE_ML_IO_TYPE_FP16:
86 : 0 : ret = rte_ml_io_float16_to_float32(nb_elements, qbuffer, dbuffer);
87 : 0 : break;
88 : 0 : default:
89 : 0 : plt_err("Unsupported qtype: %u", qtype);
90 : : ret = -ENOTSUP;
91 : : }
92 : : }
93 : :
94 : 0 : return ret;
95 : : }
|