break;
        case WLAN_EID_EXT_EHT_MULTI_LINK:
                if (ieee80211_mle_size_ok(data, len)) {
+                       elems->multi_link_elem = (void *)elem;
                        elems->multi_link = (void *)data;
                        elems->multi_link_len = len;
                }
        return found ? profile_len : 0;
 }
 
-static void ieee80211_defragment_element(struct ieee802_11_elems *elems,
-                                        void **elem_ptr, size_t *len,
-                                        size_t total_len, u8 frag_id)
-{
-       u8 *data = *elem_ptr, *pos, *start;
-       const struct element *elem;
-
-       /*
-        * Since 'data' points to the data of the element, not the element
-        * itself, allow 254 in case it was an extended element where the
-        * extended ID isn't part of the data we see here and thus not part of
-        * 'len' either.
-        */
-       if (!data || (*len != 254 && *len != 255))
-               return;
-
-       start = elems->scratch_pos;
-
-       if (WARN_ON(*len > (elems->scratch + elems->scratch_len -
-                           elems->scratch_pos)))
-               return;
-
-       memcpy(elems->scratch_pos, data, *len);
-       elems->scratch_pos += *len;
-
-       pos = data + *len;
-       total_len -= *len;
-       for_each_element(elem, pos, total_len) {
-               if (elem->id != frag_id)
-                       break;
-
-               if (WARN_ON(elem->datalen >
-                           (elems->scratch + elems->scratch_len -
-                            elems->scratch_pos)))
-                       return;
-
-               memcpy(elems->scratch_pos, elem->data, elem->datalen);
-               elems->scratch_pos += elem->datalen;
-
-               *len += elem->datalen;
-       }
-
-       *elem_ptr = start;
-}
-
 static void ieee80211_mle_get_sta_prof(struct ieee802_11_elems *elems,
                                       u8 link_id)
 {
        const struct ieee80211_multi_link_elem *ml = elems->multi_link;
-       size_t ml_len = elems->multi_link_len;
+       ssize_t ml_len = elems->multi_link_len;
        const struct element *sub;
 
        if (!ml || !ml_len)
 
        for_each_mle_subelement(sub, (u8 *)ml, ml_len) {
                struct ieee80211_mle_per_sta_profile *prof = (void *)sub->data;
+               ssize_t sta_prof_len;
                u16 control;
 
                if (sub->id != IEEE80211_MLE_SUBELEM_PER_STA_PROFILE)
                if (!(control & IEEE80211_MLE_STA_CONTROL_COMPLETE_PROFILE))
                        return;
 
-               elems->prof = prof;
-               elems->sta_prof_len = sub->datalen;
-
                /* the sub element can be fragmented */
-               ieee80211_defragment_element(elems, (void **)&elems->prof,
-                                            &elems->sta_prof_len,
-                                            ml_len - (sub->data - (u8 *)ml),
-                                            IEEE80211_MLE_SUBELEM_FRAGMENT);
+               sta_prof_len =
+                       cfg80211_defragment_element(sub,
+                                                   (u8 *)ml, ml_len,
+                                                   elems->scratch_pos,
+                                                   elems->scratch +
+                                                       elems->scratch_len -
+                                                       elems->scratch_pos,
+                                                   IEEE80211_MLE_SUBELEM_FRAGMENT);
+
+               if (sta_prof_len < 0)
+                       return;
+
+               elems->prof = (void *)elems->scratch_pos;
+               elems->sta_prof_len = sta_prof_len;
+               elems->scratch_pos += sta_prof_len;
+
                return;
        }
 }
                .from_ap = params->from_ap,
                .link_id = -1,
        };
+       ssize_t ml_len = elems->multi_link_len;
        const struct element *non_inherit = NULL;
        const u8 *end;
 
        if (params->link_id == -1)
                return;
 
-       ieee80211_defragment_element(elems, (void **)&elems->multi_link,
-                                    &elems->multi_link_len,
-                                    elems->total_len - ((u8 *)elems->multi_link -
-                                                        elems->ie_start),
-                                    WLAN_EID_FRAGMENT);
+       ml_len = cfg80211_defragment_element(elems->multi_link_elem,
+                                            elems->ie_start,
+                                            elems->total_len,
+                                            elems->scratch_pos,
+                                            elems->scratch +
+                                               elems->scratch_len -
+                                               elems->scratch_pos,
+                                            WLAN_EID_FRAGMENT);
+
+       if (ml_len < 0)
+               return;
+
+       elems->multi_link = (const void *)elems->scratch_pos;
+       elems->multi_link_len = ml_len;
+       elems->scratch_pos += ml_len;
 
        ieee80211_mle_get_sta_prof(elems, params->link_id);
        prof = elems->prof;