From e21a09967dfc32c5fc0f85a3fb0b4f036dc5fe59 Mon Sep 17 00:00:00 2001
From: Roberto Louro Magueta <rmagueta@allbesmart.pt>
Date: Wed, 15 Mar 2023 15:52:06 +0000
Subject: [PATCH] Compute LLR for QPSK for ML

---
 .../PHY/NR_TRANSPORT/nr_transport_proto.h     |  12 +
 .../PHY/NR_TRANSPORT/nr_ulsch_demodulation.c  |  65 +++--
 .../NR_TRANSPORT/nr_ulsch_llr_computation.c   | 264 ++++++++++++++++++
 3 files changed, 325 insertions(+), 16 deletions(-)

diff --git a/openair1/PHY/NR_TRANSPORT/nr_transport_proto.h b/openair1/PHY/NR_TRANSPORT/nr_transport_proto.h
index 0509ce7acf..306d44f8f5 100644
--- a/openair1/PHY/NR_TRANSPORT/nr_transport_proto.h
+++ b/openair1/PHY/NR_TRANSPORT/nr_transport_proto.h
@@ -292,6 +292,18 @@ void nr_ulsch_compute_llr(int32_t *rxdataF_comp,
 void reset_active_stats(PHY_VARS_gNB *gNB, int frame);
 void reset_active_ulsch(PHY_VARS_gNB *gNB, int frame);
 
+void nr_ulsch_compute_ML_llr(int32_t **rxdataF_comp,
+                             int32_t ***rho,
+                             int16_t **llr_layers,
+                             uint8_t nb_antennas_rx,
+                             uint32_t rb_size,
+                             uint32_t nb_re,
+                             uint8_t symbol,
+                             uint32_t rxdataF_ext_offset,
+                             uint8_t mod_order);
+
+void nr_ulsch_shift_llr(int16_t **llr_layers, uint32_t nb_re, uint32_t rxdataF_ext_offset, uint8_t mod_order, int shift);
+
 void nr_fill_ulsch(PHY_VARS_gNB *gNB,
                    int frame,
                    int slot,
diff --git a/openair1/PHY/NR_TRANSPORT/nr_ulsch_demodulation.c b/openair1/PHY/NR_TRANSPORT/nr_ulsch_demodulation.c
index 070968760b..56e4575f7a 100644
--- a/openair1/PHY/NR_TRANSPORT/nr_ulsch_demodulation.c
+++ b/openair1/PHY/NR_TRANSPORT/nr_ulsch_demodulation.c
@@ -12,6 +12,7 @@
 //#define DEBUG_CH_COMP
 //#define DEBUG_RB_EXT
 //#define DEBUG_CH_MAG
+//#define ML_DEBUG
 
 #define INVALID_VALUE 255
 
@@ -636,7 +637,7 @@ void nr_ulsch_channel_compensation(int **rxdataF_ext,
       QAM_amp128 = _mm_set1_epi16(QAM16_n1);  // 2/sqrt(10)
       QAM_amp128b = _mm_setzero_si128();
       QAM_amp128c = _mm_setzero_si128();
-    } 
+    }
     else if (mod_order == 6) {
       QAM_amp128  = _mm_set1_epi16(QAM64_n1); //
       QAM_amp128b = _mm_set1_epi16(QAM64_n2);
@@ -1081,7 +1082,7 @@ void nr_ulsch_detection_mrc(NR_DL_FRAME_PARMS *frame_parms,
                 int32_t **ul_ch_mag,
                 int32_t **ul_ch_magb,
                 int32_t **ul_ch_magc,
-                int32_t ***rho,                
+                int32_t ***rho,
                 uint8_t  nrOfLayers,
                 uint8_t symbol,
                 uint16_t nb_rb,
@@ -1115,7 +1116,7 @@ void nr_ulsch_detection_mrc(NR_DL_FRAME_PARMS *frame_parms,
         ul_ch_mag128[1]      = (__m128i *)&ul_ch_mag[aatx*frame_parms->nb_antennas_rx+aa][(symbol*(nb_re + off))];
         ul_ch_mag128b[1]     = (__m128i *)&ul_ch_magb[aatx*frame_parms->nb_antennas_rx+aa][(symbol*(nb_re + off))];
         ul_ch_mag128c[1]     = (__m128i *)&ul_ch_magc[aatx*frame_parms->nb_antennas_rx+aa][(symbol*(nb_re + off))];
-      
+
         // MRC on each re of rb, both on MF output and magnitude (for 16QAM/64QAM llr computation)
         for (i=0; i<nb_rb_0*3; i++) {
             rxdataF_comp128[0][i] = _mm_adds_epi16(rxdataF_comp128[0][i],rxdataF_comp128[1][i]);
@@ -1898,6 +1899,9 @@ void nr_rx_pusch(PHY_VARS_gNB *gNB,
                  unsigned char harq_pid)
 {
 
+  // Temporary flag: (true) ML receiver, (false) MMSE receiver
+  bool ml_rx = true;
+
   uint8_t aarx, aatx;
   uint32_t nb_re_pusch, bwp_start_subcarrier;
   int avgs = 0;
@@ -2001,8 +2005,8 @@ void nr_rx_pusch(PHY_VARS_gNB *gNB,
   int ad_shift = 0;
   if (rel15_ul->nrOfLayers == 1) {
     ad_shift = 1 + log2_approx(frame_parms->nb_antennas_rx >> 2);
-  } else {
-    ad_shift = -3; // For 2-layers, we are already doing a bit shift in the nr_ulsch_zero_forcing_rx_2layers() function, so we can use more bits
+  } else if (ml_rx == false) {
+    ad_shift = -3; // For 2-layers, we are already doing a bit shift in the nr_ulsch_mmse_2layers() function, so we can use more bits
   }
 
   for(uint8_t symbol = rel15_ul->start_symbol_index; symbol < (rel15_ul->start_symbol_index + rel15_ul->nr_of_symbols); symbol++) {
@@ -2108,7 +2112,7 @@ void nr_rx_pusch(PHY_VARS_gNB *gNB,
                              nb_re_pusch);
 
       // Apply MMSE for 2 Tx layers
-      if (rel15_ul->nrOfLayers == 2) {
+      if (ml_rx == false && rel15_ul->nrOfLayers == 2) {
         nr_ulsch_mmse_2layers(frame_parms,
                               pusch_vars->rxdataF_comp,
                               pusch_vars->ul_ch_mag0,
@@ -2159,16 +2163,45 @@ void nr_rx_pusch(PHY_VARS_gNB *gNB,
       /*--------------------  LLRs computation  -------------------------------------------------------------*/
       /*-----------------------------------------------------------------------------------------------------*/
       start_meas(&gNB->ulsch_llr_stats);
-      for (aatx=0; aatx < rel15_ul->nrOfLayers; aatx++) {
-        nr_ulsch_compute_llr(&pusch_vars->rxdataF_comp[aatx*frame_parms->nb_antennas_rx][symbol * (off + rel15_ul->rb_size * NR_NB_SC_PER_RB)],
-                             pusch_vars->ul_ch_mag0[aatx*frame_parms->nb_antennas_rx],
-                             pusch_vars->ul_ch_magb0[aatx*frame_parms->nb_antennas_rx],
-                             pusch_vars->ul_ch_magc0[aatx*frame_parms->nb_antennas_rx],
-                             &pusch_vars->llr_layers[aatx][rxdataF_ext_offset * rel15_ul->qam_mod_order],
-                             rel15_ul->rb_size,
-                             pusch_vars->ul_valid_re_per_slot[symbol],
-                             symbol,
-                             rel15_ul->qam_mod_order);
+      if (ml_rx == false || rel15_ul->nrOfLayers == 1) {
+        for (aatx=0; aatx < rel15_ul->nrOfLayers; aatx++) {
+          nr_ulsch_compute_llr(&pusch_vars->rxdataF_comp[aatx * frame_parms->nb_antennas_rx][symbol * (off + rel15_ul->rb_size * NR_NB_SC_PER_RB)],
+                               pusch_vars->ul_ch_mag0[aatx * frame_parms->nb_antennas_rx],
+                               pusch_vars->ul_ch_magb0[aatx * frame_parms->nb_antennas_rx],
+                               pusch_vars->ul_ch_magc0[aatx * frame_parms->nb_antennas_rx],
+                               &pusch_vars->llr_layers[aatx][rxdataF_ext_offset * rel15_ul->qam_mod_order],
+                               rel15_ul->rb_size,
+                               pusch_vars->ul_valid_re_per_slot[symbol],
+                               symbol,
+                               rel15_ul->qam_mod_order);
+        }
+      } else {
+        nr_ulsch_compute_ML_llr(pusch_vars->rxdataF_comp,
+                                pusch_vars->rho,
+                                pusch_vars->llr_layers,
+                                frame_parms->nb_antennas_rx,
+                                rel15_ul->rb_size,
+                                nb_re_pusch,
+                                symbol,
+                                rxdataF_ext_offset,
+                                rel15_ul->qam_mod_order);
+
+        if (rel15_ul->qam_mod_order == 2) {
+          nr_ulsch_shift_llr(pusch_vars->llr_layers, nb_re_pusch, rxdataF_ext_offset, rel15_ul->qam_mod_order, 4);
+        }
+
+#ifdef ML_DEBUG
+        c16_t *llr_layers0 = (c16_t *)&pusch_vars->llr_layers[0][rxdataF_ext_offset * rel15_ul->qam_mod_order];
+        c16_t *llr_layers1 = (c16_t *)&pusch_vars->llr_layers[1][rxdataF_ext_offset * rel15_ul->qam_mod_order];
+        printf("===============================\n");
+        printf("AFTER nr_ulsch_compute_ML_llr()\n");
+        printf("===============================\n");
+        for (int k = 0; k < nb_re_pusch; k++) {
+          printf("[%3i] llr_layers0 = (%6i, %6i), llr_layers1 = (%6i, %6i)\n",
+                 k, llr_layers0[k].r, llr_layers0[k].i, llr_layers1[k].r, llr_layers1[k].i);
+        }
+        printf("\n");
+#endif
       }
       stop_meas(&gNB->ulsch_llr_stats);
       rxdataF_ext_offset += pusch_vars->ul_valid_re_per_slot[symbol];
diff --git a/openair1/PHY/NR_TRANSPORT/nr_ulsch_llr_computation.c b/openair1/PHY/NR_TRANSPORT/nr_ulsch_llr_computation.c
index 4d4c81a986..dc5e117267 100644
--- a/openair1/PHY/NR_TRANSPORT/nr_ulsch_llr_computation.c
+++ b/openair1/PHY/NR_TRANSPORT/nr_ulsch_llr_computation.c
@@ -509,3 +509,267 @@ void nr_ulsch_compute_llr(int32_t *rxdataF_comp,
       break;
   }
 }
+
+/*
+ * This function computes the LLRs of stream 0 (s_0) in presence of the interfering stream 1 (s_1) assuming that both symbols are
+ * QPSK. It can be used for both MU-MIMO interference-aware receiver or for SU-MIMO receivers.
+ *
+ * Input:
+ *   stream0_in:  MF filter output for 1st stream, i.e., y0' = h0'*y0
+ *   stream1_in:  MF filter output for 2nd stream, i.e., y1' = h1'*y0
+ *   rho01:       Channel cross correlation, i.e., rho01 = h0'*h1
+ *   length:      Number of resource elements
+ *
+ * Output:
+ *   stream0_out: Output LLRs for 1st stream
+ */
+void nr_ulsch_qpsk_qpsk(c16_t *stream0_in, c16_t *stream1_in, c16_t *stream0_out, c16_t *rho01, uint32_t length)
+{
+  __m128i *rho01_128i = (__m128i *)rho01;
+  __m128i *stream0_128i_in = (__m128i *)stream0_in;
+  __m128i *stream1_128i_in = (__m128i *)stream1_in;
+  __m128i *stream0_128i_out = (__m128i *)stream0_out;
+  __m128i ONE_OVER_2_SQRT_2 = _mm_set1_epi16(23170); // round(2 ^ 16 / (2 * sqrt(2)))
+
+  // In each iteration, we take 8 complex symbols
+  for (int i = 0; i < length >> 2; i += 2) {
+
+    /// Compute real and imaginary parts of MF output for stream 0 (desired stream)
+
+    // Put xmm0 = [Re(0,1) Re(2,3) Im(0,1) Im(2,3)]
+    __m128i xmm0 = stream0_128i_in[i];            // 4 symbols
+    xmm0 = simde_mm_shufflelo_epi16(xmm0, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm0 = simde_mm_shufflehi_epi16(xmm0, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm0 = simde_mm_shuffle_epi32(xmm0, 0xd8);    //_MM_SHUFFLE(0,2,1,3));
+
+    // Put xmm1 = [Re(4,5) Re(6,7) Im(4,5) Im(6,7)]
+    __m128i xmm1 = stream0_128i_in[i + 1];        // 4 symbols
+    xmm1 = simde_mm_shufflelo_epi16(xmm1, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm1 = simde_mm_shufflehi_epi16(xmm1, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm1 = simde_mm_shuffle_epi32(xmm1, 0xd8);    //_MM_SHUFFLE(0,2,1,3));
+
+    __m128i y0r = simde_mm_unpacklo_epi64(xmm0, xmm1);  // y0r = Re(y0)
+    __m128i y0i = simde_mm_unpackhi_epi64(xmm0, xmm1);  // y0i = Im(y0)
+
+    __m128i y0r_over2 = simde_mm_mulhi_epi16(y0r, ONE_OVER_2_SQRT_2);
+    y0r_over2 = _mm_slli_epi16(y0r_over2, 1); // y0r_over2 = Re(y0) / sqrt(2)
+    __m128i y0i_over2 = simde_mm_mulhi_epi16(y0i, ONE_OVER_2_SQRT_2);
+    y0i_over2 = _mm_slli_epi16(y0i_over2, 1); // y0i_over2 = Im(y0) / sqrt(2)
+
+    /// Compute real and imaginary parts of MF output for stream 1 (interference stream)
+
+    // Put xmm0 = [Re(0,1) Re(2,3) Im(0,1) Im(2,3)]
+    xmm0 = stream1_128i_in[i];                    // 4 symbols
+    xmm0 = simde_mm_shufflelo_epi16(xmm0, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm0 = simde_mm_shufflehi_epi16(xmm0, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm0 = simde_mm_shuffle_epi32(xmm0, 0xd8);    //_MM_SHUFFLE(0,2,1,3));
+
+    // Put xmm1 = [Re(4,5) Re(6,7) Im(4,5) Im(6,7)]
+    xmm1 = stream1_128i_in[i + 1];                // 4 symbols
+    xmm1 = simde_mm_shufflelo_epi16(xmm1, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm1 = simde_mm_shufflehi_epi16(xmm1, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm1 = simde_mm_shuffle_epi32(xmm1, 0xd8);    //_MM_SHUFFLE(0,2,1,3));
+
+    __m128i y1r = simde_mm_unpacklo_epi64(xmm0, xmm1);  // y1r = Re(y1)
+    __m128i y1i = simde_mm_unpackhi_epi64(xmm0, xmm1);  // y1i = Im(y1)
+    __m128i y1r_over2 = simde_mm_srai_epi16(y1r, 1);          // y1r_over2 = Re(y1) / 2
+    __m128i y1i_over2 = simde_mm_srai_epi16(y1i, 1);          // y1i_over2 = Im(y1) / 2
+
+    /// Get real and imaginary parts of rho
+
+    // Put xmm0 = [Re(0,1) Re(2,3) Im(0,1) Im(2,3)]
+    xmm0 = rho01_128i[i];                         // 4 symbols
+    xmm0 = simde_mm_shufflelo_epi16(xmm0, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm0 = simde_mm_shufflehi_epi16(xmm0, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm0 = simde_mm_shuffle_epi32(xmm0, 0xd8);    //_MM_SHUFFLE(0,2,1,3));
+
+    // Put xmm1 = [Re(4,5) Re(6,7) Im(4,5) Im(6,7)]
+    xmm1 = rho01_128i[i + 1];             // 4 symbols
+    xmm1 = simde_mm_shufflelo_epi16(xmm1, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm1 = simde_mm_shufflehi_epi16(xmm1, 0xd8);  //_MM_SHUFFLE(0,2,1,3));
+    xmm1 = simde_mm_shuffle_epi32(xmm1, 0xd8);    //_MM_SHUFFLE(0,2,1,3));
+
+    __m128i rhor = simde_mm_unpacklo_epi64(xmm0, xmm1); // rhor = Re(rho)
+    __m128i rhoi = simde_mm_unpackhi_epi64(xmm0, xmm1); // rhoi = Im(rho)
+
+    /// Compute |psi_r| and |psi_i|
+
+    // psi_r = rhor * xR + rhoi * xI
+    // psi_i = rhor * xI - rhoi * xR
+
+    // Put (rho_r + rho_i)/(2*sqrt(2)) in rho_p
+    // rhor * xR + rhoi * xI  --> xR = 1/sqrt(2) and xI = 1/sqrt(2)
+    // rhor * xI - rhoi * xR  --> xR = -1/sqrt(2) and xI = 1/sqrt(2)
+    __m128i rho_p = simde_mm_adds_epi16(rhor, rhoi);        // rho_p = Re(rho) + Im(rho)
+    rho_p = simde_mm_mulhi_epi16(rho_p, ONE_OVER_2_SQRT_2); // rho_p = rho_p / (2*sqrt(2))
+
+    // Put (rho_r - rho_i)/(2*sqrt(2)) in rho_m
+    // rhor * xR + rhoi * xI  --> xR = 1/sqrt(2) and xI = -1/sqrt(2)
+    // rhor * xI - rhoi * xR  --> xR = 1/sqrt(2) and xI = 1/sqrt(2)
+    __m128i rho_m = simde_mm_subs_epi16(rhor, rhoi);        // rho_m = Re(rho) - Im(rho)
+    rho_m = simde_mm_mulhi_epi16(rho_m, ONE_OVER_2_SQRT_2); // rho_m = rho_m / (2*sqrt(2))
+
+    // xR = 1/sqrt(2) and xI = 1/sqrt(2)
+    __m128i abs_psi_rpm = simde_mm_subs_epi16(rho_p, y1r_over2);  // psi_rpm = rho_p - y1r/2
+    abs_psi_rpm = simde_mm_abs_epi16(abs_psi_rpm);                   // abs_psi_rpm = |psi_rpm|
+
+    // xR = 1/sqrt(2) and xI = 1/sqrt(2)
+    __m128i abs_psi_imm = simde_mm_subs_epi16(rho_m, y1i_over2);  // psi_imm = rho_m - y1i/2
+    abs_psi_imm = simde_mm_abs_epi16(abs_psi_imm);                   // abs_psi_imm = |psi_imm|
+
+    // xR = 1/sqrt(2) and xI = -1/sqrt(2)
+    __m128i abs_psi_rmm = simde_mm_subs_epi16(rho_m, y1r_over2);  // psi_rmm = rho_m - y1r/2
+    abs_psi_rmm = simde_mm_abs_epi16(abs_psi_rmm);                   // abs_psi_rmm = |psi_rmm|
+
+    // xR = -1/sqrt(2) and xI = 1/sqrt(2)
+    __m128i abs_psi_ipm = simde_mm_subs_epi16(rho_p, y1i_over2);  // psi_ipm = rho_p - y1i/2
+    abs_psi_ipm = simde_mm_abs_epi16(abs_psi_ipm);                   // abs_psi_ipm = |psi_ipm|
+
+    // xR = -1/sqrt(2) and xI = -1/sqrt(2)
+    __m128i abs_psi_rpp = simde_mm_adds_epi16(rho_p, y1r_over2);  // psi_rpp = rho_p + y1r/2
+    abs_psi_rpp = simde_mm_abs_epi16(abs_psi_rpp);                   // abs_psi_rpp = |psi_rpp|
+
+    // xR = -1/sqrt(2) and xI = -1/sqrt(2)
+    __m128i abs_psi_imp = simde_mm_adds_epi16(rho_m, y1i_over2);  // psi_imp = rho_m + y1i/2
+    abs_psi_imp = simde_mm_abs_epi16(abs_psi_imp);                   // abs_psi_imp = |psi_imp|
+
+    // xR = -1/sqrt(2) and xI = 1/sqrt(2)
+    __m128i abs_psi_rmp = simde_mm_adds_epi16(rho_m, y1r_over2);  // psi_rmp = rho_m + y1r/2
+    abs_psi_rmp = simde_mm_abs_epi16(abs_psi_rmp);                   // abs_psi_rmp = |psi_rmp|
+
+    // xR = 1/sqrt(2) and xI = -1/sqrt(2)
+    __m128i abs_psi_ipp = simde_mm_adds_epi16(rho_p, y1i_over2);  // psi_ipm = rho_p + y1i/2
+    abs_psi_ipp = simde_mm_abs_epi16(abs_psi_ipp);                   // abs_psi_ipp = |psi_ipm|
+
+    /// Compute bit metrics (lambda)
+
+    // lambda = max { |psi_r - y1r| * |x2R| + |psi_i - y1i| * |x2I| + y0r * xR + y0i * xI}
+
+    // xR = 1/sqrt(2) and xI = 1/sqrt(2)
+    // For numerator: bit_met_num_re_p = abs_psi_rpm + abs_psi_imm + y0r/sqrt(2) + y0i/sqrt(2)
+    __m128i bit_met_num_re_p = simde_mm_adds_epi16(abs_psi_rpm, abs_psi_imm);
+    bit_met_num_re_p = simde_mm_adds_epi16(bit_met_num_re_p, y0r_over2);
+    bit_met_num_re_p = simde_mm_adds_epi16(bit_met_num_re_p, y0i_over2);
+
+    // xR = 1/sqrt(2) and xI = -1/sqrt(2)
+    // For numerator: bit_met_num_re_m = abs_psi_rmm + abs_psi_ipp + y0r/sqrt(2) - y0i/sqrt(2)
+    __m128i bit_met_num_re_m = simde_mm_adds_epi16(abs_psi_rmm, abs_psi_ipp);
+    bit_met_num_re_m = simde_mm_adds_epi16(bit_met_num_re_m, y0r_over2);
+    bit_met_num_re_m = simde_mm_subs_epi16(bit_met_num_re_m, y0i_over2);
+
+    // xR = -1/sqrt(2) and xI = 1/sqrt(2)
+    // For denominator: bit_met_den_re_p = abs_psi_rmp + abs_psi_ipm - y0r/sqrt(2) + y0i/sqrt(2)
+    __m128i bit_met_den_re_p = simde_mm_adds_epi16(abs_psi_rmp, abs_psi_ipm);
+    bit_met_den_re_p = simde_mm_subs_epi16(bit_met_den_re_p, y0r_over2);
+    bit_met_den_re_p = simde_mm_adds_epi16(bit_met_den_re_p, y0i_over2);
+
+    // xR = -1/sqrt(2) and xI = -1/sqrt(2)
+    // For denominator: bit_met_den_re_m = abs_psi_rpp + abs_psi_imp - y0r/sqrt(2) - y0i/sqrt(2)
+    __m128i bit_met_den_re_m = simde_mm_adds_epi16(abs_psi_rpp, abs_psi_imp);
+    bit_met_den_re_m = simde_mm_subs_epi16(bit_met_den_re_m, y0r_over2);
+    bit_met_den_re_m = simde_mm_subs_epi16(bit_met_den_re_m, y0i_over2);
+
+    // xR = 1/sqrt(2) and xI = 1/sqrt(2)
+    // For numerator: bit_met_num_im_p = abs_psi_rpm + abs_psi_imm + y0r/sqrt(2) + y0i/sqrt(2)
+    __m128i bit_met_num_im_p = simde_mm_adds_epi16(abs_psi_rpm, abs_psi_imm);
+    bit_met_num_im_p = simde_mm_adds_epi16(bit_met_num_im_p, y0r_over2);
+    bit_met_num_im_p = simde_mm_adds_epi16(bit_met_num_im_p, y0i_over2);
+
+    // xR = -1/sqrt(2) and xI = 1/sqrt(2)
+    // For numerator: bit_met_num_im_m = abs_psi_rmp + abs_psi_ipm - y0r/sqrt(2) + y0i/sqrt(2)
+    __m128i bit_met_num_im_m = simde_mm_adds_epi16(abs_psi_rmp, abs_psi_ipm);
+    bit_met_num_im_m = simde_mm_subs_epi16(bit_met_num_im_m, y0r_over2);
+    bit_met_num_im_m = simde_mm_adds_epi16(bit_met_num_im_m, y0i_over2);
+
+    // xR = 1/sqrt(2) and xI = -1/sqrt(2)
+    // For denominator: bit_met_den_im_p = abs_psi_rmm + abs_psi_ipp + y0r/sqrt(2) - y0i/sqrt(2)
+    __m128i bit_met_den_im_p = simde_mm_adds_epi16(abs_psi_rmm, abs_psi_ipp);
+    bit_met_den_im_p = simde_mm_adds_epi16(bit_met_den_im_p, y0r_over2);
+    bit_met_den_im_p = simde_mm_subs_epi16(bit_met_den_im_p, y0i_over2);
+
+    // xR = -1/sqrt(2) and xI = -1/sqrt(2)
+    // For denominator: bit_met_den_im_m = abs_psi_rpp + abs_psi_imp - y0r/sqrt(2)- y0i/sqrt(2)
+    __m128i bit_met_den_im_m = simde_mm_adds_epi16(abs_psi_rpp, abs_psi_imp);
+    bit_met_den_im_m = simde_mm_subs_epi16(bit_met_den_im_m, y0r_over2);
+    bit_met_den_im_m = simde_mm_subs_epi16(bit_met_den_im_m, y0i_over2);
+
+    /// Compute the LLRs
+
+    // LLR = lambda(c==1) - lambda(c==0)
+
+    __m128i logmax_num_re0 = simde_mm_max_epi16(bit_met_num_re_p, bit_met_num_re_m); // LLR of the first bit: Bit = 1
+    __m128i logmax_den_re0 = simde_mm_max_epi16(bit_met_den_re_p, bit_met_den_re_m); // LLR of the first bit: Bit = 0
+    __m128i logmax_num_im0 = simde_mm_max_epi16(bit_met_num_im_p, bit_met_num_im_m); // LLR of the second bit: Bit = 1
+    __m128i logmax_den_im0 = simde_mm_max_epi16(bit_met_den_im_p, bit_met_den_im_m); // LLR of the second bit: Bit = 0
+
+    y0r = simde_mm_subs_epi16(logmax_num_re0, logmax_den_re0);  // LLR of first bit [L1(1), L1(2), L1(3), L1(4)]
+    y0i = simde_mm_subs_epi16(logmax_num_im0, logmax_den_im0);  // LLR of second bit [L2(1), L2(2), L2(3), L2(4)]
+
+    // [L1(1), L2(1), L1(2), L2(2)]
+    simde_mm_storeu_si128(&stream0_128i_out[i], simde_mm_unpacklo_epi16(y0r, y0i));
+
+    // false if only 2 REs remain
+    if (i < ((length >> 1) - 1)) {
+      simde_mm_storeu_si128(&stream0_128i_out[i + 1], simde_mm_unpackhi_epi16(y0r, y0i));
+    }
+  }
+
+  _mm_empty();
+  _m_empty();
+}
+
+
+void nr_ulsch_compute_ML_llr(int32_t **rxdataF_comp,
+                             int32_t ***rho,
+                             int16_t **llr_layers,
+                             uint8_t nb_antennas_rx,
+                             uint32_t rb_size,
+                             uint32_t nb_re,
+                             uint8_t symbol,
+                             uint32_t rxdataF_ext_offset,
+                             uint8_t mod_order)
+{
+  int off = ((rb_size & 1) == 1) ? 4 : 0;
+  c16_t *rxdataF_comp0 = (c16_t *)&rxdataF_comp[0][symbol * (off + (rb_size * NR_NB_SC_PER_RB))];
+  c16_t *rxdataF_comp1 = (c16_t *)&rxdataF_comp[nb_antennas_rx][symbol * (off + (rb_size * NR_NB_SC_PER_RB))];
+  c16_t *llr_layers0 = (c16_t *)&llr_layers[0][rxdataF_ext_offset * mod_order];
+  c16_t *llr_layers1 = (c16_t *)&llr_layers[1][rxdataF_ext_offset * mod_order];
+  c16_t *rho0 = (c16_t *)&rho[0][1][symbol * (off + (rb_size * NR_NB_SC_PER_RB))];
+  c16_t *rho1 = (c16_t *)&rho[0][2][symbol * (off + (rb_size * NR_NB_SC_PER_RB))];
+
+  switch (mod_order) {
+    case 2:
+      nr_ulsch_qpsk_qpsk(rxdataF_comp0, rxdataF_comp1, llr_layers0, rho0, nb_re);
+      nr_ulsch_qpsk_qpsk(rxdataF_comp1, rxdataF_comp0, llr_layers1, rho1, nb_re);
+      break;
+    case 4:
+    case 6:
+      AssertFatal(1 == 0, "LLR computation is not implemented yet for ML with Qm = %d\n", mod_order);
+    default:
+      AssertFatal(1 == 0, "nr_ulsch_compute_llr: invalid Qm value, symbol = %d, Qm = %d\n", symbol, mod_order);
+  }
+}
+
+void nr_ulsch_shift_llr(int16_t **llr_layers, uint32_t nb_re, uint32_t rxdataF_ext_offset, uint8_t mod_order, int shift)
+{
+  __m128i *llr_layers0 = (__m128i *)&llr_layers[0][rxdataF_ext_offset * mod_order];
+  __m128i *llr_layers1 = (__m128i *)&llr_layers[1][rxdataF_ext_offset * mod_order];
+
+  uint8_t mem_offset = ((16 - ((long)llr_layers0)) & 0xF) >> 2;
+
+  if (mem_offset > 0) {
+    c16_t *llr_layers0_c16 = (c16_t *)&llr_layers[0][rxdataF_ext_offset * mod_order];
+    c16_t *llr_layers1_c16 = (c16_t *)&llr_layers[1][rxdataF_ext_offset * mod_order];
+    for (int i = 0; i < mem_offset; i++) {
+      llr_layers0_c16[i] = c16Shift(llr_layers0_c16[i], shift);
+      llr_layers1_c16[i] = c16Shift(llr_layers1_c16[i], shift);
+    }
+    llr_layers0 = (__m128i *)&llr_layers[0][rxdataF_ext_offset * mod_order + (mem_offset << 1)];
+    llr_layers1 = (__m128i *)&llr_layers[1][rxdataF_ext_offset * mod_order + (mem_offset << 1)];
+  }
+
+  for (int i = 0; i < nb_re >> 2; i++) {
+    llr_layers0[i] = simde_mm_srai_epi16(llr_layers0[i], shift);
+    llr_layers1[i] = simde_mm_srai_epi16(llr_layers1[i], shift);
+  }
+}
\ No newline at end of file
-- 
2.26.2