#include "rc_calc.h"
 #include "fixed31_32.h"
 
+#include "clk_mgr.h"
+#include "resource.h"
+
 #define DC_LOGGER \
        dsc->ctx->logger
 
 }
 
 /* Forward Declerations */
+static unsigned int get_min_slice_count_for_odm(
+               const struct display_stream_compressor *dsc,
+               const struct dsc_enc_caps *dsc_enc_caps,
+               const struct dc_crtc_timing *timing);
+
 static bool decide_dsc_bandwidth_range(
                const uint32_t min_bpp_x16,
                const uint32_t max_bpp_x16,
                const struct dc_crtc_timing *timing,
                const struct dc_dsc_config_options *options,
                const enum dc_link_encoding_format link_encoding,
+               int min_slice_count,
                struct dc_dsc_config *dsc_cfg);
 
 static bool dsc_buff_block_size_from_dpcd(int dpcd_buff_block_size, int *buff_block_size)
        return true;
 }
 
-
 /* If DSC is possbile, get DSC bandwidth range based on [min_bpp, max_bpp] target bitrate range and
  * timing's pixel clock and uncompressed bandwidth.
  * If DSC is not possible, leave '*range' untouched.
                struct dc_dsc_bw_range *range)
 {
        bool is_dsc_possible = false;
+       unsigned int min_slice_count;
        struct dsc_enc_caps dsc_enc_caps;
        struct dsc_enc_caps dsc_common_caps;
        struct dc_dsc_config config = {0};
 
        get_dsc_enc_caps(dsc, &dsc_enc_caps, timing->pix_clk_100hz);
 
+       min_slice_count = get_min_slice_count_for_odm(dsc, &dsc_enc_caps, timing);
+
        is_dsc_possible = intersect_dsc_caps(dsc_sink_caps, &dsc_enc_caps,
                        timing->pixel_encoding, &dsc_common_caps);
 
        if (is_dsc_possible)
                is_dsc_possible = setup_dsc_config(dsc_sink_caps, &dsc_enc_caps, 0, timing,
-                               &options, link_encoding, &config);
+                               &options, link_encoding, min_slice_count, &config);
 
        if (is_dsc_possible)
                is_dsc_possible = decide_dsc_bandwidth_range(min_bpp_x16, max_bpp_x16,
        DC_LOG_DSC("\tis_dp %d", dsc_sink_caps->is_dp);
 }
 
+
+/*
+ * OR into dsc_enc_caps->slice_caps every slice count that can be produced by
+ * running up to max_odm_combine_factor DSC engines side by side, where each
+ * engine supports the slice counts listed in single_dsc_enc_caps.
+ *
+ * Bits are only ever added; whatever is already set in dsc_enc_caps is kept.
+ */
+static void build_dsc_enc_combined_slice_caps(
+               const struct dsc_enc_caps *single_dsc_enc_caps,
+               struct dsc_enc_caps *dsc_enc_caps,
+               unsigned int max_odm_combine_factor)
+{
+       /* one engine: its own slice counts carry over unchanged */
+       dsc_enc_caps->slice_caps.raw |= single_dsc_enc_caps->slice_caps.raw;
+
+       if (max_odm_combine_factor < 2)
+               return;
+
+       /* two engines: 1 + 1 -> 2, 2 + 2 -> 4, 4 + 4 -> 8, 8 + 8 -> 16 */
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_2 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_1;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_4 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_2;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_8 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_4;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_16 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_8;
+
+       if (max_odm_combine_factor < 3)
+               return;
+
+       /* three engines: 4 + 4 + 4 -> 12 */
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_12 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_4;
+
+       if (max_odm_combine_factor < 4)
+               return;
+
+       /* four engines: 4 x 1 -> 4, 4 x 2 -> 8, 4 x 3 -> 12, 4 x 4 -> 16 */
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_4 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_1;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_8 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_2;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_12 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_3;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_16 |= single_dsc_enc_caps->slice_caps.bits.NUM_SLICES_4;
+}
+
+/*
+ * Derive the overall DSC encoder capabilities from the capabilities of one
+ * DSC engine, as reported by the dsc_get_single_enc_caps hook for the maximum
+ * DSCCLK returned by the clock manager.
+ *
+ * dsc_enc_caps is left untouched when dsc, its context, the dc instance, the
+ * single-engine hook, the clock manager (or its get_max_clock_khz hook) or
+ * the resource pool is missing.
+ */
+static void build_dsc_enc_caps(
+               const struct display_stream_compressor *dsc,
+               struct dsc_enc_caps *dsc_enc_caps)
+{
+       struct dsc_enc_caps single_caps;
+       struct dc *dc;
+       unsigned int dscclk_khz;
+       unsigned int odm_factor;
+       unsigned int dsc_count;
+
+       memset(&single_caps, 0, sizeof(single_caps));
+
+       if (!dsc || !dsc->ctx || !dsc->ctx->dc)
+               return;
+       if (!dsc->funcs->dsc_get_single_enc_caps)
+               return;
+
+       dc = dsc->ctx->dc;
+
+       if (!dc->clk_mgr || !dc->clk_mgr->funcs->get_max_clock_khz)
+               return;
+       if (!dc->res_pool)
+               return;
+
+       /* query one engine at the highest DSCCLK the clock manager reports */
+       dscclk_khz = dc->clk_mgr->funcs->get_max_clock_khz(dc->clk_mgr, CLK_TYPE_DSCCLK);
+       dsc->funcs->dsc_get_single_enc_caps(&single_caps, dscclk_khz);
+
+       /* capabilities copied as-is from the single engine */
+       dsc_enc_caps->dsc_version = single_caps.dsc_version;
+       dsc_enc_caps->lb_bit_depth = single_caps.lb_bit_depth;
+       dsc_enc_caps->is_block_pred_supported = single_caps.is_block_pred_supported;
+       dsc_enc_caps->max_slice_width = single_caps.max_slice_width;
+       dsc_enc_caps->bpp_increment_div = single_caps.bpp_increment_div;
+       dsc_enc_caps->color_formats.raw = single_caps.color_formats.raw;
+       dsc_enc_caps->color_depth.raw = single_caps.color_depth.raw;
+
+       /* the combine factor cannot exceed the number of DSC engines */
+       odm_factor = dc->caps.max_odm_combine_factor;
+       dsc_count = dc->res_pool->res_cap->num_dsc;
+       if (dsc_count < odm_factor)
+               odm_factor = dsc_count;
+
+       /* total throughput is the single-engine throughput times that factor */
+       dsc_enc_caps->max_total_throughput_mps =
+                       single_caps.max_total_throughput_mps * odm_factor;
+
+       /* add the slice counts that become possible with ODM combine */
+       build_dsc_enc_combined_slice_caps(&single_caps, dsc_enc_caps, odm_factor);
+}
+
+/*
+ * Divide value by 10, rounding up.
+ * The addition is done in uint32_t, so it wraps for value > UINT32_MAX - 9.
+ */
+static inline uint32_t dsc_div_by_10_round_up(uint32_t value)
+{
+       return (value + 9) / 10;
+}
+
+/*
+ * Return the minimum slice count required by
+ *   1) display pipe throughput: pixel clock relative to the max DISPCLK, or
+ *      relative to max_total_throughput_mps * 1000 when the clock manager
+ *      cannot report DISPCLK, and
+ *   2) the maximum slice width: active width plus borders relative to
+ *      dsc_enc_caps->max_slice_width.
+ *
+ * Returns 1 (no additional constraint) when either divisor is 0. That is the
+ * case when dsc_enc_caps is still zeroed, e.g. because get_dsc_enc_caps()
+ * was called with a NULL dsc; dividing by it would be a division by zero.
+ */
+static unsigned int get_min_slice_count_for_odm(
+               const struct display_stream_compressor *dsc,
+               const struct dsc_enc_caps *dsc_enc_caps,
+               const struct dc_crtc_timing *timing)
+{
+       unsigned int max_dispclk_khz;
+       struct dc *dc = NULL;
+
+       if (dsc && dsc->ctx)
+               dc = dsc->ctx->dc;
+
+       /* get max pixel rate and combine caps */
+       max_dispclk_khz = dsc_enc_caps->max_total_throughput_mps * 1000;
+       if (dc && dc->clk_mgr && dc->clk_mgr->funcs->get_max_clock_khz) {
+               /* dispclk is available */
+               max_dispclk_khz = dc->clk_mgr->funcs->get_max_clock_khz(dc->clk_mgr, CLK_TYPE_DISPCLK);
+       }
+
+       /* validate parameters: both values are used as divisors below */
+       if (max_dispclk_khz == 0 || dsc_enc_caps->max_slice_width == 0)
+               return 1;
+
+       /* consider minimum odm slices required due to
+        * 1) display pipe throughput (dispclk)
+        * 2) max image width per slice
+        */
+       return dc_fixpt_ceil(dc_fixpt_max(
+                       dc_fixpt_div_int(dc_fixpt_from_int(dsc_div_by_10_round_up(timing->pix_clk_100hz)),
+                       max_dispclk_khz), // throughput
+                       dc_fixpt_div_int(dc_fixpt_from_int(timing->h_addressable + timing->h_border_left + timing->h_border_right),
+                       dsc_enc_caps->max_slice_width))); // slice width
+}
+
+/*
+ * Fill *dsc_enc_caps with the DSC encoder capabilities.
+ *
+ * The structure is always zeroed first and stays zeroed when dsc is NULL.
+ * If the DSC block provides a dsc_get_enc_caps hook, that hook fills the caps
+ * (skipped when debug.disable_dsc is set); otherwise build_dsc_enc_caps()
+ * assembles them. debug.native422_support forces the YCBCR_NATIVE_422 color
+ * format bit on either way.
+ */
 static void get_dsc_enc_caps(
                 const struct display_stream_compressor *dsc,
                 struct dsc_enc_caps *dsc_enc_caps,
                 int pixel_clock_100Hz)
 {
-       // This is a static HW query, so we can use any DSC
-
         memset(dsc_enc_caps, 0, sizeof(struct dsc_enc_caps));
-       if (dsc) {
+
+       if (!dsc)
+               return;
+
+       /* check whether the reported caps are global or only for a single DCN DSC encoder */
+       if (dsc->funcs->dsc_get_enc_caps) {
                 if (!dsc->ctx->dc->debug.disable_dsc)
                         dsc->funcs->dsc_get_enc_caps(dsc_enc_caps, pixel_clock_100Hz);
-               if (dsc->ctx->dc->debug.native422_support)
-                       dsc_enc_caps->color_formats.bits.YCBCR_NATIVE_422 = 1;
+       } else {
+               /* NOTE(review): unlike the branch above, this path does not
+                * consult debug.disable_dsc - confirm that is intended.
+                */
+               build_dsc_enc_caps(dsc, dsc_enc_caps);
         }
+
+       if (dsc->ctx->dc->debug.native422_support)
+               dsc_enc_caps->color_formats.bits.YCBCR_NATIVE_422 = 1;
 }
 
 /* Returns 'false' if no intersection was found for at least one capability.
        return true;
 }
 
-static inline uint32_t dsc_div_by_10_round_up(uint32_t value)
-{
-       return (value + 9) / 10;
-}
-
 static uint32_t compute_bpp_x16_from_target_bandwidth(
        const uint32_t bandwidth_in_kbps,
        const struct dc_crtc_timing *timing,
                const struct dc_crtc_timing *timing,
                const struct dc_dsc_config_options *options,
                const enum dc_link_encoding_format link_encoding,
+               int min_slices_h,
                struct dc_dsc_config *dsc_cfg)
 {
        struct dsc_enc_caps dsc_common_caps;
        int max_slices_h = 0;
-       int min_slices_h = 0;
        int num_slices_h = 0;
        int pic_width;
        int slice_width;
        if (!is_dsc_possible)
                goto done;
 
-       min_slices_h = pic_width / dsc_common_caps.max_slice_width;
-       if (pic_width % dsc_common_caps.max_slice_width)
-               min_slices_h++;
-
        min_slices_h = fit_num_slices_up(dsc_common_caps.slice_caps, min_slices_h);
 
+       /* increase minimum slice count to meet sink throughput limitations */
        while (min_slices_h <= max_slices_h) {
                int pix_clk_per_slice_khz = dsc_div_by_10_round_up(timing->pix_clk_100hz) / min_slices_h;
                if (pix_clk_per_slice_khz <= sink_per_slice_throughput_mps * 1000)
                min_slices_h = inc_num_slices(dsc_common_caps.slice_caps, min_slices_h);
        }
 
-       is_dsc_possible = (min_slices_h <= max_slices_h);
-
-       if (pic_width % min_slices_h != 0)
-               min_slices_h = 0; // DSC TODO: Maybe try increasing the number of slices first?
-
-       if (min_slices_h == 0 && max_slices_h == 0)
-               is_dsc_possible = false;
+       /* increase minimum slice count to meet divisibility requirements */
+       while (pic_width % min_slices_h != 0 && min_slices_h <= max_slices_h) {
+               min_slices_h = inc_num_slices(dsc_common_caps.slice_caps, min_slices_h);
+       }
 
+       is_dsc_possible = (min_slices_h <= max_slices_h) && max_slices_h != 0;
        if (!is_dsc_possible)
                goto done;
 
 {
        bool is_dsc_possible = false;
        struct dsc_enc_caps dsc_enc_caps;
-
+       unsigned int min_slice_count;
        get_dsc_enc_caps(dsc, &dsc_enc_caps, timing->pix_clk_100hz);
+
+       min_slice_count = get_min_slice_count_for_odm(dsc, &dsc_enc_caps, timing);
+
        is_dsc_possible = setup_dsc_config(dsc_sink_caps,
                &dsc_enc_caps,
                target_bandwidth_kbps,
-               timing, options, link_encoding, dsc_cfg);
+               timing,
+               options,
+               link_encoding,
+               min_slice_count,
+               dsc_cfg);
        return is_dsc_possible;
 }
 
 
 #include "dsc/dscc_types.h"
 #include "dsc/rc_calc.h"
 
-#define MAX_THROUGHPUT_PER_DSC_100HZ 20000000
-#define MAX_DSC_UNIT_COMBINE 4
-
 static void dsc_write_to_registers(struct display_stream_compressor *dsc, const struct dsc_reg_values *reg_vals);
 
 /* Object I/F functions */
 //static void dsc401_get_enc_caps(struct dsc_enc_caps *dsc_enc_caps, int pixel_clock_100Hz);
 //static bool dsc401_get_packed_pps(struct display_stream_compressor *dsc, const struct dsc_config *dsc_cfg, uint8_t *dsc_packed_pps);
+static void dsc401_get_single_enc_caps(struct dsc_enc_caps *dsc_enc_caps, unsigned int max_dscclk_khz);
 
 static const struct dsc_funcs dcn401_dsc_funcs = {
-       .dsc_get_enc_caps = dsc401_get_enc_caps,
         .dsc_read_state = dsc401_read_state,
         .dsc_validate_stream = dsc401_validate_stream,
         .dsc_set_config = dsc401_set_config,
         .dsc_disable = dsc401_disable,
         .dsc_disconnect = dsc401_disconnect,
         .dsc_wait_disconnect_pending_clear = dsc401_wait_disconnect_pending_clear,
+       /* No .dsc_get_enc_caps hook is provided, so get_dsc_enc_caps() takes
+        * its build_dsc_enc_caps() path and expands these single-DSC caps.
+        */
+       .dsc_get_single_enc_caps = dsc401_get_single_enc_caps,
 };
 
 /* Macro definitios for REG_SET macros*/
        dsc->max_image_width = 5184;
 }
 
-void dsc401_get_enc_caps(struct dsc_enc_caps *dsc_enc_caps, int pixel_clock_100Hz)
+static void dsc401_get_single_enc_caps(struct dsc_enc_caps *dsc_enc_caps, unsigned int max_dscclk_khz)
 {
-       int min_dsc_unit_required = (pixel_clock_100Hz + MAX_THROUGHPUT_PER_DSC_100HZ - 1) / MAX_THROUGHPUT_PER_DSC_100HZ;
-
        dsc_enc_caps->dsc_version = 0x21; /* v1.2 - DP spec defined it in reverse order and we kept it */
 
-       /* 1 slice is only supported with 1 DSC unit */
-       dsc_enc_caps->slice_caps.bits.NUM_SLICES_1 = min_dsc_unit_required == 1 ? 1 : 0;
-       /* 2 slice is only supported with 1 or 2 DSC units */
-       dsc_enc_caps->slice_caps.bits.NUM_SLICES_2 = (min_dsc_unit_required == 1 || min_dsc_unit_required == 2) ? 1 : 0;
-       /* 3 slice is only supported with 1 DSC unit */
-       dsc_enc_caps->slice_caps.bits.NUM_SLICES_3 = min_dsc_unit_required == 1 ? 1 : 0;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_1 = 1;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_2 = 1;
+       dsc_enc_caps->slice_caps.bits.NUM_SLICES_3 = 1;
        dsc_enc_caps->slice_caps.bits.NUM_SLICES_4 = 1;
-       dsc_enc_caps->slice_caps.bits.NUM_SLICES_8 = 1;
-       dsc_enc_caps->slice_caps.bits.NUM_SLICES_12 = 1;
-       dsc_enc_caps->slice_caps.bits.NUM_SLICES_16 = 1;
 
        dsc_enc_caps->lb_bit_depth = 13;
        dsc_enc_caps->is_block_pred_supported = true;
        dsc_enc_caps->color_depth.bits.COLOR_DEPTH_8_BPC = 1;
        dsc_enc_caps->color_depth.bits.COLOR_DEPTH_10_BPC = 1;
        dsc_enc_caps->color_depth.bits.COLOR_DEPTH_12_BPC = 1;
-       dsc_enc_caps->max_total_throughput_mps = MAX_THROUGHPUT_PER_DSC_100HZ * MAX_DSC_UNIT_COMBINE;
+       dsc_enc_caps->max_total_throughput_mps = max_dscclk_khz * 3 / 1000;
 
        dsc_enc_caps->max_slice_width = 5184; /* (including 64 overlap pixels for eDP MSO mode) */
        dsc_enc_caps->bpp_increment_div = 16; /* 1/16th of a bit */