* Even though 40 bits are present, use only 32 for ease of calculation.
  */
 #define ALPHA_REG_BITWIDTH     40
-#define ALPHA_BITWIDTH         32
-#define ALPHA_16BIT_MASK       0xffff
+#define ALPHA_REG_16BIT_WIDTH  16
+#define ALPHA_BITWIDTH         32U
+#define ALPHA_SHIFT(w)         min(w, ALPHA_BITWIDTH)
+
+#define pll_alpha_width(p)                                     \
+               ((PLL_ALPHA_VAL_U(p) - PLL_ALPHA_VAL(p) == 4) ? \
+                                ALPHA_REG_BITWIDTH : ALPHA_REG_16BIT_WIDTH)
 
 #define to_clk_alpha_pll(_hw) container_of(to_clk_regmap(_hw), \
                                           struct clk_alpha_pll, clkr)
        regmap_update_bits(pll->clkr.regmap, PLL_MODE(pll), mask, 0);
 }
 
-static unsigned long alpha_pll_calc_rate(u64 prate, u32 l, u32 a)
+static unsigned long
+alpha_pll_calc_rate(u64 prate, u32 l, u32 a, u32 alpha_width)
 {
-       return (prate * l) + ((prate * a) >> ALPHA_BITWIDTH);
+       return (prate * l) + ((prate * a) >> ALPHA_SHIFT(alpha_width));
 }
 
 static unsigned long
-alpha_pll_round_rate(unsigned long rate, unsigned long prate, u32 *l, u64 *a)
+alpha_pll_round_rate(unsigned long rate, unsigned long prate, u32 *l, u64 *a,
+                    u32 alpha_width)
 {
        u64 remainder;
        u64 quotient;
        }
 
        /* Upper ALPHA_BITWIDTH bits of Alpha */
-       quotient = remainder << ALPHA_BITWIDTH;
+       quotient = remainder << ALPHA_SHIFT(alpha_width);
+
        remainder = do_div(quotient, prate);
 
        if (remainder)
                quotient++;
 
        *a = quotient;
-       return alpha_pll_calc_rate(prate, *l, *a);
+       return alpha_pll_calc_rate(prate, *l, *a, alpha_width);
 }
 
 static const struct pll_vco *
        u32 l, low, high, ctl;
        u64 a = 0, prate = parent_rate;
        struct clk_alpha_pll *pll = to_clk_alpha_pll(hw);
+       u32 alpha_width = pll_alpha_width(pll);
 
        regmap_read(pll->clkr.regmap, PLL_L_VAL(pll), &l);
 
        regmap_read(pll->clkr.regmap, PLL_USER_CTL(pll), &ctl);
        if (ctl & PLL_ALPHA_EN) {
                regmap_read(pll->clkr.regmap, PLL_ALPHA_VAL(pll), &low);
-               if (pll->flags & SUPPORTS_16BIT_ALPHA) {
-                       a = low & ALPHA_16BIT_MASK;
-               } else {
+               if (alpha_width > 32) {
                        regmap_read(pll->clkr.regmap, PLL_ALPHA_VAL_U(pll),
                                    &high);
                        a = (u64)high << 32 | low;
-                       a >>= ALPHA_REG_BITWIDTH - ALPHA_BITWIDTH;
+               } else {
+                       a = low & GENMASK(alpha_width - 1, 0);
                }
+
+               if (alpha_width > ALPHA_BITWIDTH)
+                       a >>= alpha_width - ALPHA_BITWIDTH;
        }
 
-       return alpha_pll_calc_rate(prate, l, a);
+       return alpha_pll_calc_rate(prate, l, a, alpha_width);
 }
 
 static int clk_alpha_pll_set_rate(struct clk_hw *hw, unsigned long rate,
 {
        struct clk_alpha_pll *pll = to_clk_alpha_pll(hw);
        const struct pll_vco *vco;
-       u32 l;
+       u32 l, alpha_width = pll_alpha_width(pll);
        u64 a;
 
-       rate = alpha_pll_round_rate(rate, prate, &l, &a);
+       rate = alpha_pll_round_rate(rate, prate, &l, &a, alpha_width);
        vco = alpha_pll_find_vco(pll, rate);
        if (!vco) {
                pr_err("alpha pll not in a valid vco range\n");
 
        regmap_write(pll->clkr.regmap, PLL_L_VAL(pll), l);
 
-       if (pll->flags & SUPPORTS_16BIT_ALPHA) {
-               regmap_write(pll->clkr.regmap, PLL_ALPHA_VAL(pll),
-                            a & ALPHA_16BIT_MASK);
-       } else {
-               a <<= (ALPHA_REG_BITWIDTH - ALPHA_BITWIDTH);
-               regmap_write(pll->clkr.regmap, PLL_ALPHA_VAL_U(pll),
-                            a >> 32);
-       }
+       if (alpha_width > ALPHA_BITWIDTH)
+               a <<= alpha_width - ALPHA_BITWIDTH;
+
+       if (alpha_width > 32)
+               regmap_write(pll->clkr.regmap, PLL_ALPHA_VAL_U(pll), a >> 32);
+
+       regmap_write(pll->clkr.regmap, PLL_ALPHA_VAL(pll), a);
 
        regmap_update_bits(pll->clkr.regmap, PLL_USER_CTL(pll),
                           PLL_VCO_MASK << PLL_VCO_SHIFT,
                                     unsigned long *prate)
 {
        struct clk_alpha_pll *pll = to_clk_alpha_pll(hw);
-       u32 l;
+       u32 l, alpha_width = pll_alpha_width(pll);
        u64 a;
        unsigned long min_freq, max_freq;
 
-       rate = alpha_pll_round_rate(rate, *prate, &l, &a);
+       rate = alpha_pll_round_rate(rate, *prate, &l, &a, alpha_width);
        if (alpha_pll_find_vco(pll, rate))
                return rate;