if (bch2_snapshot_equiv(c, id))
                return 0;
 
-       /* 0 is an invalid tree ID */
+       /* Do we need to reconstruct the snapshot_tree entry as well? */
+       struct btree_iter iter;
+       struct bkey_s_c k;
+       int ret = 0;
        u32 tree_id = 0;
-       int ret = bch2_snapshot_tree_create(trans, id, 0, &tree_id);
+
+       for_each_btree_key_norestart(trans, iter, BTREE_ID_snapshot_trees, POS_MIN,
+                                    0, k, ret) {
+               if (le32_to_cpu(bkey_s_c_to_snapshot_tree(k).v->root_snapshot) == id) {
+                       tree_id = k.k->p.offset;
+                       break;
+               }
+       }
+       bch2_trans_iter_exit(trans, &iter);
+
        if (ret)
                return ret;
 
+       if (!tree_id) {
+               ret = bch2_snapshot_tree_create(trans, id, 0, &tree_id);
+               if (ret)
+                       return ret;
+       }
+
        struct bkey_i_snapshot *snapshot = bch2_trans_kmalloc(trans, sizeof(*snapshot));
        ret = PTR_ERR_OR_ZERO(snapshot);
        if (ret)
        snapshot->v.tree        = cpu_to_le32(tree_id);
        snapshot->v.btime.lo    = cpu_to_le64(bch2_current_time(c));
 
+       for_each_btree_key_norestart(trans, iter, BTREE_ID_subvolumes, POS_MIN,
+                                    0, k, ret) {
+               if (le32_to_cpu(bkey_s_c_to_subvolume(k).v->snapshot) == id) {
+                       snapshot->v.subvol = cpu_to_le32(k.k->p.offset);
+                       SET_BCH_SNAPSHOT_SUBVOL(&snapshot->v, true);
+                       break;
+               }
+       }
+       bch2_trans_iter_exit(trans, &iter);
+
        return  bch2_btree_insert_trans(trans, BTREE_ID_snapshots, &snapshot->k_i, 0) ?:
                bch2_mark_snapshot(trans, BTREE_ID_snapshots, 0,
                                   bkey_s_c_null, bkey_i_to_s(&snapshot->k_i), 0) ?: