From a3f8cb99b889a8d529fdab91a4464884ee7b5e27 Mon Sep 17 00:00:00 2001
From: rickyskv <schiavon@eurecom.fr>
Date: Thu, 25 Jun 2020 20:07:52 +0200
Subject: [PATCH] polar decoder 8 bits solved problems with the structure
 generated testbench option it runs, time decreases increasing SNR but same
 BLER is it counting correctly the BLER?

---
 openair1/PHY/CODING/TESTBENCH/polartest.c     | 27 ++++++++--
 .../CODING/nrPolar_tools/nr_polar_decoder.c   | 10 ++--
 .../nrPolar_tools/nr_polar_decoding_tools.c   | 50 +++++++++----------
 .../PHY/CODING/nrPolar_tools/nr_polar_defs.h  | 38 ++++++++++++++
 .../nrPolar_tools/nr_polar_procedures.c       | 21 ++++++++
 .../nrPolar_tools/nr_polar_rate_match.c       |  4 +-
 openair1/PHY/CODING/nr_polar_init.c           |  1 +
 7 files changed, 114 insertions(+), 37 deletions(-)

diff --git a/openair1/PHY/CODING/TESTBENCH/polartest.c b/openair1/PHY/CODING/TESTBENCH/polartest.c
index d9adf3e24f..bb4ccd9f75 100644
--- a/openair1/PHY/CODING/TESTBENCH/polartest.c
+++ b/openair1/PHY/CODING/TESTBENCH/polartest.c
@@ -26,6 +26,7 @@ int main(int argc, char *argv[])
 {
   //Default simulation values (Aim for iterations = 1000000.)
   int decoder_int16=0;
+  int16_t decoder_int8=0;
   int itr, iterations = 1000, arguments, polarMessageType = 0; //0=PBCH, 1=DCI, 2=UCI
   double SNRstart = -20.0, SNRstop = 0.0, SNRinc= 0.5; //dB
   double SNR, SNR_lin;
@@ -37,7 +38,7 @@ int main(int argc, char *argv[])
   uint8_t aggregation_level = 8, decoderListSize = 8, logFlag = 0;
   uint16_t rnti=0;
 
-  while ((arguments = getopt (argc, argv, "s:d:f:m:i:l:a:p:hqgFL:k:")) != -1)
+  while ((arguments = getopt (argc, argv, "s:d:f:m:i:l:a:p:q:hgFL:k:")) != -1)
     switch (arguments) {
     case 's':
     	SNRstart = atof(optarg);
@@ -68,6 +69,13 @@ int main(int argc, char *argv[])
 
     case 'q':
     	decoder_int16 = 1;
+	decoder_int8=atoi(optarg);
+    	if (decoder_int8 != 8 && decoder_int8 != 1 && decoder_int8 != 0 && decoder_int8 != 16) {
+    		printf("Illegal argument for option -q: %d \nPossible values: 0 or 16 to use 16-bit decoder, 1 or 8 to use the 8-bit decoder\n",decoder_int8);
+    		exit(-1);
+    	}
+    	if (decoder_int8 == 8) decoder_int8=1;
+	if (decoder_int8 == 16) decoder_int8=0;
     	break;
 
     case 'g':
@@ -99,7 +107,7 @@ int main(int argc, char *argv[])
 
     case 'h':
       printf("./polartest\nOptions\n-h Print this help\n-s SNRstart (dB)\n-d SNRinc (dB)\n-f SNRstop (dB)\n-m [0=PBCH|1=DCI|2=UCI]\n"
-             "-i Number of iterations\n-l decoderListSize\n-q Flag for optimized coders usage\n-F Flag for test results logging\n"
+             "-i Number of iterations\n-l decoderListSize\n-q Flag for optimized coders usage [0 = 16-bit, 1 = 8-bit]\n-F Flag for test results logging\n"
     		 "-L aggregation level (for DCI)\n-k packet_length (bits) for DCI/UCI\n");
       exit(-1);
       break;
@@ -175,8 +183,9 @@ if (logFlag){
   double modulatedInput[coderLength]; //channel input
   double channelOutput[coderLength];  //add noise
   int16_t channelOutput_int16[coderLength];
+  int8_t channelOutput_int8[coderLength];
 
-  t_nrPolar_params *currentPtr = nr_polar_params(polarMessageType, testLength, aggregation_level, 1, NULL);
+  t_nrPolar_params *currentPtr = nr_polar_params(polarMessageType, testLength, aggregation_level, decoder_int8+1, NULL);
 
 #ifdef DEBUG_DCI_POLAR_PARAMS
   uint32_t dci_pdu[4];
@@ -272,16 +281,28 @@ if (logFlag){
     	  channelOutput[i] = modulatedInput[i] + (gaussdouble(0.0,1.0) * (1/sqrt(2*SNR_lin)));
 
     	  if (decoder_int16==1) {
+		if(decoder_int8==0){
     		  if (channelOutput[i] > 15) channelOutput_int16[i] = 127;
     		  else if (channelOutput[i] < -16) channelOutput_int16[i] = -128;
     		  else channelOutput_int16[i] = (int16_t) (8*channelOutput[i]);
+		}
+		else{
+    		  if (channelOutput[i] > 15) channelOutput_int8[i] = 63;
+    		  else if (channelOutput[i] < -16) channelOutput_int8[i] = -64;
+    		  else channelOutput_int8[i] = (int8_t) (4*channelOutput[i]);
+		}
     	  }
       }
 
       start_meas(&timeDecoder);
 
       if (decoder_int16==1) {
+	if(decoder_int8==0){
     	  decoderState = polar_decoder_int16(channelOutput_int16, (uint64_t *)estimatedOutput, 0, currentPtr);
+	}
+	else{
+    	  decoderState = polar_decoder_int8(channelOutput_int8, (uint64_t *)estimatedOutput, 0, currentPtr);
+	}
       } else { //0 --> PBCH, 1 --> DCI, -1 --> UCI
     	  if (polarMessageType == 0) {
     		  decoderState = polar_decoder(channelOutput,
diff --git a/openair1/PHY/CODING/nrPolar_tools/nr_polar_decoder.c b/openair1/PHY/CODING/nrPolar_tools/nr_polar_decoder.c
index cc9142d9c3..99017035d7 100644
--- a/openair1/PHY/CODING/nrPolar_tools/nr_polar_decoder.c
+++ b/openair1/PHY/CODING/nrPolar_tools/nr_polar_decoder.c
@@ -708,20 +708,18 @@ uint32_t polar_decoder_int16(int16_t *input,
 
 // ############### INT 8 #########################
 
-uint32_t polar_decoder_int8(int16_t *input,
+uint32_t polar_decoder_int8(int8_t *input,
                              uint64_t *out,
                              uint8_t ones_flag,
                              const t_nrPolar_params *polarParams)
 {
-  int16_t d_tilde[polarParams->N];// = malloc(sizeof(double) * polarParams->N);
-  nr_polar_rate_matching_int16(input, d_tilde, polarParams->rate_matching_pattern, polarParams->K, polarParams->N, polarParams->encoderLength);
-
+  int8_t d_tilde[polarParams->N];// = malloc(sizeof(double) * polarParams->N);
+  nr_polar_rate_matching_int8(input, d_tilde, polarParams->rate_matching_pattern, polarParams->K, polarParams->N, polarParams->encoderLength);
   for (int i=0; i<polarParams->N; i++) {
     if (d_tilde[i]<-128) d_tilde[i]=-128;
     else if (d_tilde[i]>127) d_tilde[i]=128;
   }
-
-  memcpy((void *)&polarParams->tree.root->alpha[0],(void *)&d_tilde[0],sizeof(int16_t)*polarParams->N);
+  memcpy((void *)&polarParams->tree.root->alpha8[0],(void *)&d_tilde[0],sizeof(int8_t)*polarParams->N);
   generic_polar_decoder_int8(polarParams,polarParams->tree.root);
   //Extract the information bits (没 to 膲)
   uint64_t Cprime[4]= {0,0,0,0};
diff --git a/openair1/PHY/CODING/nrPolar_tools/nr_polar_decoding_tools.c b/openair1/PHY/CODING/nrPolar_tools/nr_polar_decoding_tools.c
index 199c2c8cbe..a924caae80 100644
--- a/openair1/PHY/CODING/nrPolar_tools/nr_polar_decoding_tools.c
+++ b/openair1/PHY/CODING/nrPolar_tools/nr_polar_decoding_tools.c
@@ -518,9 +518,9 @@ decoder_node_t *new_decoder_node_int8(int first_leaf_index, int level) {
   node->left=(decoder_node_t *)NULL;
   node->right=(decoder_node_t *)NULL;
   node->all_frozen=0;
-  node->alpha  = (int8_t*)malloc16(node->Nv*sizeof(int8_t));
-  node->beta   = (int8_t*)malloc16(node->Nv*sizeof(int8_t));
-  memset((void*)node->beta,-1,node->Nv*sizeof(int8_t));
+  node->alpha8  = (int8_t*)malloc16(node->Nv*sizeof(int8_t));
+  node->beta8   = (int8_t*)malloc16(node->Nv*sizeof(int8_t));
+  memset((void*)node->beta8,-1,node->Nv*sizeof(int8_t));
   
   return(node);
 }
@@ -589,9 +589,9 @@ void build_decoder_tree_int8(t_nrPolar_params *polarParams)
 #endif
 
 void applyFtoleft_int8(const t_nrPolar_params *pp, decoder_node_t *node) {
-  int8_t *alpha_v=node->alpha;
-  int8_t *alpha_l=node->left->alpha;
-  int8_t *betal = node->left->beta;
+  int8_t *alpha_v=node->alpha8;
+  int8_t *alpha_l=node->left->alpha8;
+  int8_t *betal = node->left->beta8;
   int8_t a,b,absa,absb,maska,maskb,minabs;
 
 #ifdef DEBUG_NEW_IMPL
@@ -635,7 +635,7 @@ void applyFtoleft_int8(const t_nrPolar_params *pp, decoder_node_t *node) {
       b64       =((__m64*)alpha_v)[1];
       absa64    =_mm_abs_pi8(a64);
       absb64    =_mm_abs_pi8(b64);
-      minabs64  =_mm_min_pi8(absa64,absb64);
+      minabs64  =_mm_min_pu8(absa64,absb64);
       *((__m64*)alpha_l) =_mm_sign_pi8(minabs64,_mm_sign_pi8(a64,b64));
     }
     else
@@ -663,7 +663,7 @@ void applyFtoleft_int8(const t_nrPolar_params *pp, decoder_node_t *node) {
       b64       =((__m64*)alpha_v)[1];
       absa64    =_mm_abs_pi8(a64);
       absb64    =_mm_abs_pi8(b64);
-      minabs64  =_mm_min_pi8(absa64,absb64);
+      minabs64  =_mm_min_pu8(absa64,absb64);
       *((__m64*)alpha_l) =_mm_sign_pi8(minabs64,_mm_sign_epi8(a64,b64));
     }
 
@@ -697,10 +697,10 @@ void applyFtoleft_int8(const t_nrPolar_params *pp, decoder_node_t *node) {
 
 void applyGtoright_int8(const t_nrPolar_params *pp,decoder_node_t *node) {
 
-  int8_t *alpha_v=node->alpha;
-  int8_t *alpha_r=node->right->alpha;
-  int8_t *betal = node->left->beta;
-  int8_t *betar = node->right->beta;
+  int8_t *alpha_v=node->alpha8;
+  int8_t *alpha_r=node->right->alpha8;
+  int8_t *betal = node->left->beta8;
+  int8_t *betar = node->right->beta8;
 
 #ifdef DEBUG_NEW_IMPL
   printf("applyGtoright %d, Nv %d (level %d), (leaf %d, AF %d)\n",node->first_leaf_index,node->Nv,node->level,node->right->leaf,node->right->all_frozen);
@@ -713,10 +713,7 @@ void applyGtoright_int8(const t_nrPolar_params *pp,decoder_node_t *node) {
       int avx2len = node->Nv/2/32;
       
       for (int i=0;i<avx2len;i++) {
-	((__m256i *)alpha_r)[i] = 
-	  _mm256_subs_epi8(((__m256i *)alpha_v)[i+avx2len],
-			    _mm256_sign_epi8(((__m256i *)alpha_v)[i],
-					      ((__m256i *)betal)[i]));	
+	((__m256i *)alpha_r)[i] = _mm256_subs_epi8(((__m256i *)alpha_v)[i+avx2len], _mm256_sign_epi8(((__m256i *)alpha_v)[i], ((__m256i *)betal)[i]));	
       }
     }
     else if (avx2mod == 16) {
@@ -757,20 +754,20 @@ void applyGtoright_int8(const t_nrPolar_params *pp,decoder_node_t *node) {
   }
 }
 
-int8_t all1[8] = {1,1,1,1,1,1,1,1};
+int8_t all1_int8[8] = {1,1,1,1,1,1,1,1};
 
 void computeBeta_int8(const t_nrPolar_params *pp,decoder_node_t *node) {
 
-  int8_t *betav = node->beta;
-  int8_t *betal = node->left->beta;
-  int8_t *betar = node->right->beta;
+  int8_t *betav = node->beta8;
+  int8_t *betal = node->left->beta8;
+  int8_t *betar = node->right->beta8;
 #ifdef DEBUG_NEW_IMPL
   printf("Computing beta @ level %d first_leaf_index %d (all_frozen %d)\n",node->level,node->first_leaf_index,node->left->all_frozen);
 #endif
   if (node->left->all_frozen==0) { // if left node is not aggregation of frozen bits
 #if defined(__AVX2__) 
     int avx2mod = (node->Nv/2)&31;
-    register __m256i allones=*((__m256i*)all1);
+    register __m256i allones=*((__m256i*)all1_int8);
     if (avx2mod == 0) {
       int avx2len = node->Nv/2/32;
       for (int i=0;i<avx2len;i++) {
@@ -780,11 +777,11 @@ void computeBeta_int8(const t_nrPolar_params *pp,decoder_node_t *node) {
     }
     else if (avx2mod == 16) {
       ((__m128i*)betav)[0] = _mm_or_si128(_mm_cmpeq_epi8(((__m128i*)betar)[0],
-							  ((__m128i*)betal)[0]),*((__m128i*)all1));
+							  ((__m128i*)betal)[0]),*((__m128i*)all1_int8));
     }
     else if (avx2mod == 8) {
       ((__m64*)betav)[0] = _mm_or_si64(_mm_cmpeq_pi8(((__m64*)betar)[0],
-						      ((__m64*)betal)[0]),*((__m64*)all1));
+						      ((__m64*)betal)[0]),*((__m64*)all1_int8));
     }
     else
 #else
@@ -792,13 +789,13 @@ void computeBeta_int8(const t_nrPolar_params *pp,decoder_node_t *node) {
     int ssr4mod = (node->Nv/2)&15;
     if (ssr4mod == 0) {
       int ssr4len = node->Nv/2/16;
-      register __m128i allones=*((__m128i*)all1);
+      register __m128i allones=*((__m128i*)all1_int8);
       for (int i=0;i<sse4len;i++) {
       ((__m128i*)betav)[i] = _mm_or_si128(_mm_cmpeq_epi8(((__m128i*)betar)[i], ((__m128i*)betal)[i]),allones);
       }
     }
     else if (sse4mod == 8) {
-      ((__m64*)betav)[0] = _mm_or_si64(_mm_cmpeq_pi8(((__m64*)betar)[0], ((__m64*)betal)[0]),*((__m64*)all1));
+      ((__m64*)betav)[0] = _mm_or_si64(_mm_cmpeq_pi8(((__m64*)betar)[0], ((__m64*)betal)[0]),*((__m64*)all1_int8));
     }
     else
 #endif
@@ -817,11 +814,12 @@ void generic_polar_decoder_int8(const t_nrPolar_params *pp,decoder_node_t *node)
 
   // Apply F to left
   applyFtoleft_int8(pp, node);
+
   // if left is not a leaf recurse down to the left
   if (node->left->leaf==0)
     generic_polar_decoder_int8(pp, node->left);
 
-  applyGtoright(pp, node);
+  applyGtoright_int8(pp, node);
   if (node->right->leaf==0) generic_polar_decoder_int8(pp, node->right);
 
   computeBeta_int8(pp, node);
diff --git a/openair1/PHY/CODING/nrPolar_tools/nr_polar_defs.h b/openair1/PHY/CODING/nrPolar_tools/nr_polar_defs.h
index 6e01a05d0f..a9d44152d8 100644
--- a/openair1/PHY/CODING/nrPolar_tools/nr_polar_defs.h
+++ b/openair1/PHY/CODING/nrPolar_tools/nr_polar_defs.h
@@ -70,6 +70,8 @@ typedef struct decoder_node_t_s {
   int all_frozen;
   int16_t *alpha;
   int16_t *beta;
+  int8_t *alpha8;
+  int8_t *beta8;
 } decoder_node_t;
 
 typedef struct decoder_tree_t_s {
@@ -161,6 +163,14 @@ uint32_t polar_decoder_int16(int16_t *input,
                              uint8_t ones_flag,
                              const t_nrPolar_params *polarParams);
 
+// ############## INT 8 ##############
+uint32_t polar_decoder_int8(int8_t *input,
+                             uint64_t *out,
+                             uint8_t ones_flag,
+                             const t_nrPolar_params *polarParams);
+
+//################ END INT 8 ################
+
 int8_t polar_decoder_dci(double *input,
                          uint32_t *out,
                          t_nrPolar_params *polarParams,
@@ -180,6 +190,25 @@ void computeBeta(const t_nrPolar_params *pp,
 				 decoder_node_t *node);
 
 void build_decoder_tree(t_nrPolar_params *pp);
+
+//################ INT 8 ##############
+
+void generic_polar_decoder_int8(const t_nrPolar_params *pp,
+						   decoder_node_t *node);
+
+void applyFtoleft_int8(const t_nrPolar_params *pp,
+				  decoder_node_t *node);
+
+void applyGtoright_int8(const t_nrPolar_params *pp,
+				   decoder_node_t *node);
+
+void computeBeta_int8(const t_nrPolar_params *pp,
+				 decoder_node_t *node);
+
+void build_decoder_tree_int8(t_nrPolar_params *pp);
+
+//################ END INT 8 ################
+
 void build_polar_tables(t_nrPolar_params *polarParams);
 void init_polar_deinterleaver_table(t_nrPolar_params *polarParams);
 
@@ -231,6 +260,15 @@ void nr_polar_rate_matching_int16(int16_t *input,
                                   uint16_t N,
                                   uint16_t E);
 
+//########### INT 8 ################
+void nr_polar_rate_matching_int8(int8_t *input,
+                                  int8_t *output,
+                                  uint16_t *rmp,
+                                  uint16_t K,
+                                  uint16_t N,
+                                  uint16_t E);
+//########### END INT 8 ################
+
 void nr_polar_interleaving_pattern(uint16_t K,
                                    uint8_t I_IL,
                                    uint16_t *PI_k_);
diff --git a/openair1/PHY/CODING/nrPolar_tools/nr_polar_procedures.c b/openair1/PHY/CODING/nrPolar_tools/nr_polar_procedures.c
index 26832fd70c..50ed196128 100644
--- a/openair1/PHY/CODING/nrPolar_tools/nr_polar_procedures.c
+++ b/openair1/PHY/CODING/nrPolar_tools/nr_polar_procedures.c
@@ -345,3 +345,24 @@ void nr_polar_rate_matching_int16(int16_t *input,
    
   }
 }
+
+void nr_polar_rate_matching_int8(int8_t *input,
+				  int8_t *output,
+				  uint16_t *rmp,
+				  uint16_t K,
+				  uint16_t N,
+				  uint16_t E)
+{
+  if (E>=N) { //repetition
+    memset((void*)output,0,N*sizeof(int8_t));
+    for (int i=0; i<=E-1; i++) output[rmp[i]]+=input[i];
+  } else {
+    if ( (K/(double)E) <= (7.0/16) ) memset((void*)output,0,N*sizeof(int8_t)); //puncturing
+    else { //shortening
+      for (int i=0; i<=N-1; i++) output[i]=127;//instead of INFINITY, to prevent [-Woverflow]
+    }
+
+    for (int i=0; i<=E-1; i++) output[rmp[i]]=input[i];
+   
+  }
+}
diff --git a/openair1/PHY/CODING/nrPolar_tools/nr_polar_rate_match.c b/openair1/PHY/CODING/nrPolar_tools/nr_polar_rate_match.c
index d237603059..e7c2f495e0 100644
--- a/openair1/PHY/CODING/nrPolar_tools/nr_polar_rate_match.c
+++ b/openair1/PHY/CODING/nrPolar_tools/nr_polar_rate_match.c
@@ -80,7 +80,7 @@ void nr_polar_rate_matching(double *input, double *output, uint16_t *rmp, uint16
 
 }
 
-void nr_polar_rate_matching_int8(int16_t *input, int16_t *output, uint16_t *rmp, uint16_t K, uint16_t N, uint16_t E){
+/*void nr_polar_rate_matching_int8(int16_t *input, int16_t *output, uint16_t *rmp, uint16_t K, uint16_t N, uint16_t E){
 
 	if (E>=N) { //repetition
 		for (int i=0; i<=N-1; i++) output[i]=0;
@@ -99,4 +99,4 @@ void nr_polar_rate_matching_int8(int16_t *input, int16_t *output, uint16_t *rmp,
 		}
 	}
 
-}
+}*/
diff --git a/openair1/PHY/CODING/nr_polar_init.c b/openair1/PHY/CODING/nr_polar_init.c
index f6a700117d..9880ee95fe 100644
--- a/openair1/PHY/CODING/nr_polar_init.c
+++ b/openair1/PHY/CODING/nr_polar_init.c
@@ -183,6 +183,7 @@ static void nr_polar_init(t_nrPolar_params * *polarParams,
                                          newPolarInitNode->encoderLength);
     free(J);
     if (decoder_flag == 1) build_decoder_tree(newPolarInitNode);
+    if (decoder_flag == 2) build_decoder_tree_int8(newPolarInitNode);
     build_polar_tables(newPolarInitNode);
     init_polar_deinterleaver_table(newPolarInitNode);
     //printf("decoder tree nodes %d\n",newPolarInitNode->tree.num_nodes);
-- 
2.26.2