#ifndef _CIFSPROTO_H
 #define _CIFSPROTO_H
 #include <linux/nls.h>
+#include <linux/ctype.h>
 #include "trace.h"
 #ifdef CONFIG_CIFS_DFS_UPCALL
 #include "dfs_cache.h"
 extern struct TCP_Server_Info *
 cifs_find_tcp_session(struct smb3_fs_context *ctx);
 
-extern void cifs_put_smb_ses(struct cifs_ses *ses);
+void __cifs_put_smb_ses(struct cifs_ses *ses);
 
 extern struct cifs_ses *
 cifs_get_smb_ses(struct TCP_Server_Info *server, struct smb3_fs_context *ctx);
 void cifs_put_tcon_super(struct super_block *sb);
 int cifs_wait_for_server_reconnect(struct TCP_Server_Info *server, bool retry);
 
+/* Put references of @ses and @ses->dfs_root_ses */
+static inline void cifs_put_smb_ses(struct cifs_ses *ses)
+{
+       struct cifs_ses *rses = ses->dfs_root_ses;
+
+       __cifs_put_smb_ses(ses);
+       if (rses)
+               __cifs_put_smb_ses(rses);
+}
+
+/* Get an active reference of @ses and @ses->dfs_root_ses.
+ *
+ * NOTE: make sure to call this function when incrementing reference count of
+ * @ses to ensure that any DFS root session attached to it (@ses->dfs_root_ses)
+ * will also get its reference count incremented.
+ *
+ * cifs_put_smb_ses() will put both references, so call it when you're done.
+ */
+static inline void cifs_smb_ses_inc_refcount(struct cifs_ses *ses)
+{
+       lockdep_assert_held(&cifs_tcp_ses_lock);
+
+       ses->ses_count++;
+       if (ses->dfs_root_ses)
+               ses->dfs_root_ses->ses_count++;
+}
+
+static inline bool dfs_src_pathname_equal(const char *s1, const char *s2)
+{
+       if (strlen(s1) != strlen(s2))
+               return false;
+       for (; *s1; s1++, s2++) {
+               if (*s1 == '/' || *s1 == '\\') {
+                       if (*s2 != '/' && *s2 != '\\')
+                               return false;
+               } else if (tolower(*s1) != tolower(*s2))
+                       return false;
+       }
+       return true;
+}
+
 #endif                 /* _CIFSPROTO_H */
 
                 */
        }
 
-#ifdef CONFIG_CIFS_DFS_UPCALL
        kfree(server->origin_fullpath);
        kfree(server->leaf_fullpath);
-#endif
        kfree(server);
 
        length = atomic_dec_return(&tcpSesAllocCount);
        return true;
 }
 
-static bool dfs_src_pathname_equal(const char *s1, const char *s2)
-{
-       if (strlen(s1) != strlen(s2))
-               return false;
-       for (; *s1; s1++, s2++) {
-               if (*s1 == '/' || *s1 == '\\') {
-                       if (*s2 != '/' && *s2 != '\\')
-                               return false;
-               } else if (tolower(*s1) != tolower(*s2))
-                       return false;
-       }
-       return true;
-}
-
 /* this function must be called with srv_lock held */
-static int match_server(struct TCP_Server_Info *server, struct smb3_fs_context *ctx,
-                       bool dfs_super_cmp)
+static int match_server(struct TCP_Server_Info *server, struct smb3_fs_context *ctx)
 {
        struct sockaddr *addr = (struct sockaddr *)&ctx->dstaddr;
 
                               (struct sockaddr *)&server->srcaddr))
                return 0;
        /*
-        * When matching DFS superblocks, we only check for original source pathname as the
-        * currently connected target might be different than the one parsed earlier in i.e.
-        * mount.cifs(8).
+        * - Match for an DFS tcon (@server->origin_fullpath).
+        * - Match for an DFS root server connection (@server->leaf_fullpath).
+        * - If none of the above and @ctx->leaf_fullpath is set, then
+        *   it is a new DFS connection.
+        * - If 'nodfs' mount option was passed, then match only connections
+        *   that have no DFS referrals set
+        *   (e.g. can't failover to other targets).
         */
-       if (dfs_super_cmp) {
-               if (!ctx->source || !server->origin_fullpath ||
-                   !dfs_src_pathname_equal(server->origin_fullpath, ctx->source))
-                       return 0;
-       } else {
-               /* Skip addr, hostname and port matching for DFS connections */
-               if (server->leaf_fullpath) {
+       if (!ctx->nodfs) {
+               if (ctx->source && server->origin_fullpath) {
+                       if (!dfs_src_pathname_equal(ctx->source,
+                                                   server->origin_fullpath))
+                               return 0;
+               } else if (server->leaf_fullpath) {
                        if (!ctx->leaf_fullpath ||
-                           strcasecmp(server->leaf_fullpath, ctx->leaf_fullpath))
+                           strcasecmp(server->leaf_fullpath,
+                                      ctx->leaf_fullpath))
                                return 0;
-               } else if (strcasecmp(server->hostname, ctx->server_hostname) ||
-                          !match_server_address(server, addr) ||
-                          !match_port(server, addr)) {
+               } else if (ctx->leaf_fullpath) {
                        return 0;
                }
+       } else if (server->origin_fullpath || server->leaf_fullpath) {
+               return 0;
        }
 
+       /*
+        * Match for a regular connection (address/hostname/port) which has no
+        * DFS referrals set.
+        */
+       if (!server->origin_fullpath && !server->leaf_fullpath &&
+           (strcasecmp(server->hostname, ctx->server_hostname) ||
+            !match_server_address(server, addr) ||
+            !match_port(server, addr)))
+               return 0;
+
        if (!match_security(server, ctx))
                return 0;
 
                 * Skip ses channels since they're only handled in lower layers
                 * (e.g. cifs_send_recv).
                 */
-               if (CIFS_SERVER_IS_CHAN(server) || !match_server(server, ctx, false)) {
+               if (CIFS_SERVER_IS_CHAN(server) || !match_server(server, ctx)) {
                        spin_unlock(&server->srv_lock);
                        continue;
                }
 static struct cifs_ses *
 cifs_find_smb_ses(struct TCP_Server_Info *server, struct smb3_fs_context *ctx)
 {
-       struct cifs_ses *ses;
+       struct cifs_ses *ses, *ret = NULL;
 
        spin_lock(&cifs_tcp_ses_lock);
        list_for_each_entry(ses, &server->smb_ses_list, smb_ses_list) {
                        continue;
                }
                spin_lock(&ses->chan_lock);
-               if (!match_session(ses, ctx)) {
+               if (match_session(ses, ctx)) {
                        spin_unlock(&ses->chan_lock);
                        spin_unlock(&ses->ses_lock);
-                       continue;
+                       ret = ses;
+                       break;
                }
                spin_unlock(&ses->chan_lock);
                spin_unlock(&ses->ses_lock);
-
-               ++ses->ses_count;
-               spin_unlock(&cifs_tcp_ses_lock);
-               return ses;
        }
+       if (ret)
+               cifs_smb_ses_inc_refcount(ret);
        spin_unlock(&cifs_tcp_ses_lock);
-       return NULL;
+       return ret;
 }
 
-void cifs_put_smb_ses(struct cifs_ses *ses)
+void __cifs_put_smb_ses(struct cifs_ses *ses)
 {
        unsigned int rc, xid;
        unsigned int chan_count;
         */
        spin_lock(&cifs_tcp_ses_lock);
        ses->dfs_root_ses = ctx->dfs_root_ses;
+       if (ses->dfs_root_ses)
+               ses->dfs_root_ses->ses_count++;
        list_add(&ses->smb_ses_list, &server->smb_ses_list);
        spin_unlock(&cifs_tcp_ses_lock);
 
 }
 
 /* this function must be called with tc_lock held */
-static int match_tcon(struct cifs_tcon *tcon, struct smb3_fs_context *ctx, bool dfs_super_cmp)
+static int match_tcon(struct cifs_tcon *tcon, struct smb3_fs_context *ctx)
 {
+       struct TCP_Server_Info *server = tcon->ses->server;
+
        if (tcon->status == TID_EXITING)
                return 0;
-       /* Skip UNC validation when matching DFS superblocks */
-       if (!dfs_super_cmp && strncmp(tcon->tree_name, ctx->UNC, MAX_TREE_SIZE))
+       /* Skip UNC validation when matching DFS connections or superblocks */
+       if (!server->origin_fullpath && !server->leaf_fullpath &&
+           strncmp(tcon->tree_name, ctx->UNC, MAX_TREE_SIZE))
                return 0;
        if (tcon->seal != ctx->seal)
                return 0;
        spin_lock(&cifs_tcp_ses_lock);
        list_for_each_entry(tcon, &ses->tcon_list, tcon_list) {
                spin_lock(&tcon->tc_lock);
-               if (!match_tcon(tcon, ctx, false)) {
+               if (!match_tcon(tcon, ctx)) {
                        spin_unlock(&tcon->tc_lock);
                        continue;
                }
        return 1;
 }
 
-static int
-match_prepath(struct super_block *sb, struct cifs_mnt_data *mnt_data)
+static int match_prepath(struct super_block *sb,
+                        struct TCP_Server_Info *server,
+                        struct cifs_mnt_data *mnt_data)
 {
+       struct smb3_fs_context *ctx = mnt_data->ctx;
        struct cifs_sb_info *old = CIFS_SB(sb);
        struct cifs_sb_info *new = mnt_data->cifs_sb;
        bool old_set = (old->mnt_cifs_flags & CIFS_MOUNT_USE_PREFIX_PATH) &&
        bool new_set = (new->mnt_cifs_flags & CIFS_MOUNT_USE_PREFIX_PATH) &&
                new->prepath;
 
+       if (server->origin_fullpath &&
+           dfs_src_pathname_equal(server->origin_fullpath, ctx->source))
+               return 1;
+
        if (old_set && new_set && !strcmp(new->prepath, old->prepath))
                return 1;
        else if (!old_set && !new_set)
        struct cifs_ses *ses;
        struct cifs_tcon *tcon;
        struct tcon_link *tlink;
-       bool dfs_super_cmp;
        int rc = 0;
 
        spin_lock(&cifs_tcp_ses_lock);
        ses = tcon->ses;
        tcp_srv = ses->server;
 
-       dfs_super_cmp = IS_ENABLED(CONFIG_CIFS_DFS_UPCALL) && tcp_srv->origin_fullpath;
-
        ctx = mnt_data->ctx;
 
        spin_lock(&tcp_srv->srv_lock);
        spin_lock(&ses->ses_lock);
        spin_lock(&ses->chan_lock);
        spin_lock(&tcon->tc_lock);
-       if (!match_server(tcp_srv, ctx, dfs_super_cmp) ||
+       if (!match_server(tcp_srv, ctx) ||
            !match_session(ses, ctx) ||
-           !match_tcon(tcon, ctx, dfs_super_cmp) ||
-           !match_prepath(sb, mnt_data)) {
+           !match_tcon(tcon, ctx) ||
+           !match_prepath(sb, tcp_srv, mnt_data)) {
                rc = 0;
                goto out;
        }
 
 error:
        dfs_put_root_smb_sessions(&mnt_ctx.dfs_ses_list);
-       kfree(mnt_ctx.origin_fullpath);
-       kfree(mnt_ctx.leaf_fullpath);
        cifs_mount_put_conns(&mnt_ctx);
        return rc;
 }
 
        return rc;
 }
 
-static int get_root_smb_session(struct cifs_mount_ctx *mnt_ctx)
+static int add_root_smb_session(struct cifs_mount_ctx *mnt_ctx)
 {
        struct smb3_fs_context *ctx = mnt_ctx->fs_ctx;
        struct dfs_root_ses *root_ses;
 {
        struct smb3_fs_context *ctx = mnt_ctx->fs_ctx;
        struct dfs_info3_param ref = {};
-       bool is_refsrv = false;
+       bool is_refsrv;
        int rc, rc2;
 
        rc = dfs_cache_get_tgt_referral(ref_path + 1, tit, &ref);
        dfs_cache_noreq_update_tgthint(ref_path + 1, tit);
 
        if (rc == -EREMOTE && is_refsrv) {
-               rc2 = get_root_smb_session(mnt_ctx);
+               rc2 = add_root_smb_session(mnt_ctx);
                if (rc2)
                        rc = rc2;
        }
 
 int dfs_mount_share(struct cifs_mount_ctx *mnt_ctx, bool *isdfs)
 {
-       struct cifs_sb_info *cifs_sb = mnt_ctx->cifs_sb;
        struct smb3_fs_context *ctx = mnt_ctx->fs_ctx;
+       struct cifs_ses *ses;
+       char *source = ctx->source;
+       bool nodfs = ctx->nodfs;
        int rc;
 
        *isdfs = false;
-
+       /* Temporarily set @ctx->source to NULL as we're not matching DFS
+        * superblocks yet.  See cifs_match_super() and match_server().
+        */
+       ctx->source = NULL;
        rc = get_session(mnt_ctx, NULL);
        if (rc)
-               return rc;
+               goto out;
+
        ctx->dfs_root_ses = mnt_ctx->ses;
        /*
         * If called with 'nodfs' mount option, then skip DFS resolving.  Otherwise unconditionally
         * Skip prefix path to provide support for DFS referrals from w2k8 servers which don't seem
         * to respond with PATH_NOT_COVERED to requests that include the prefix.
         */
-       if ((cifs_sb->mnt_cifs_flags & CIFS_MOUNT_NO_DFS) ||
-           dfs_get_referral(mnt_ctx, ctx->UNC + 1, NULL, NULL)) {
+       if (!nodfs) {
+               rc = dfs_get_referral(mnt_ctx, ctx->UNC + 1, NULL, NULL);
+               if (rc) {
+                       if (rc != -ENOENT && rc != -EOPNOTSUPP)
+                               goto out;
+                       nodfs = true;
+               }
+       }
+       if (nodfs) {
                rc = cifs_mount_get_tcon(mnt_ctx);
-               if (rc)
-                       return rc;
-
-               rc = cifs_is_path_remote(mnt_ctx);
-               if (!rc || rc != -EREMOTE)
-                       return rc;
+               if (!rc)
+                       rc = cifs_is_path_remote(mnt_ctx);
+               goto out;
        }
 
        *isdfs = true;
-       rc = get_root_smb_session(mnt_ctx);
-       if (rc)
-               return rc;
-
-       return __dfs_mount_share(mnt_ctx);
+       /*
+        * Prevent DFS root session of being put in the first call to
+        * cifs_mount_put_conns().  If another DFS root server was not found
+        * while chasing the referrals (@ctx->dfs_root_ses == @ses), then we
+        * can safely put extra refcount of @ses.
+        */
+       ses = mnt_ctx->ses;
+       mnt_ctx->ses = NULL;
+       mnt_ctx->server = NULL;
+       rc = __dfs_mount_share(mnt_ctx);
+       if (ses == ctx->dfs_root_ses)
+               cifs_put_smb_ses(ses);
+out:
+       /*
+        * Restore previous value of @ctx->source so DFS superblock can be
+        * matched in cifs_match_super().
+        */
+       ctx->source = source;
+       return rc;
 }
 
 /* Update dfs referral path of superblock */