extern int amd_special_default_mtrr(void);
 void mtrr_disable(void);
 void mtrr_enable(void);
+void mtrr_generic_set_state(void);
 #  else
 static inline u8 mtrr_type_lookup(u64 addr, u64 end, u8 *uniform)
 {
 #define mtrr_bp_restore() do {} while (0)
 #define mtrr_disable() do {} while (0)
 #define mtrr_enable() do {} while (0)
+#define mtrr_generic_set_state() do {} while (0)
 #  endif
 
 #ifdef CONFIG_COMPAT
 
 
        raw_spin_unlock(&cache_disable_lock);
 }
+
+void cache_cpu_init(void)
+{
+       unsigned long flags;
+
+       local_irq_save(flags);
+       cache_disable();
+
+       if (memory_caching_control & CACHE_MTRR)
+               mtrr_generic_set_state();
+
+       if (memory_caching_control & CACHE_PAT)
+               pat_init();
+
+       cache_enable();
+       local_irq_restore(flags);
+}
 
        mtrr_wrmsr(MSR_MTRRdefType, deftype_lo, deftype_hi);
 }
 
-static void generic_set_all(void)
+void mtrr_generic_set_state(void)
 {
        unsigned long mask, count;
-       unsigned long flags;
-
-       local_irq_save(flags);
-       cache_disable();
 
        /* Actually set the state */
        mask = set_mtrr_state();
 
-       /* also set PAT */
-       pat_init();
-
-       cache_enable();
-       local_irq_restore(flags);
-
        /* Use the atomic bitops to update the global mask */
        for (count = 0; count < sizeof(mask) * 8; ++count) {
                if (mask & 0x01)
                        set_bit(count, &smp_changes_mask);
                mask >>= 1;
        }
-
 }
 
 /**
  * Generic structure...
  */
 const struct mtrr_ops generic_mtrr_ops = {
-       .set_all                = generic_set_all,
+       .set_all                = cache_cpu_init,
        .get                    = generic_get_mtrr,
        .get_free_region        = generic_get_free_region,
        .set                    = generic_set_mtrr,