#include "fs_context.h"
 #include "dfs.h"
 
+#define DFS_DOM(ctx) (ctx->dfs_root_ses ? ctx->dfs_root_ses->dns_dom : NULL)
+
 /**
  * dfs_parse_target_referral - set fs context for dfs target referral
  *
        if (rc)
                goto out;
 
-       rc = dns_resolve_server_name_to_ip(path, (struct sockaddr *)&ctx->dstaddr, NULL);
-
+       rc = dns_resolve_server_name_to_ip(DFS_DOM(ctx), path,
+                                          (struct sockaddr *)&ctx->dstaddr,
+                                          NULL);
 out:
        kfree(path);
        return rc;
        int rc;
 
        ctx->leaf_fullpath = (char *)full_path;
+       ctx->dns_dom = DFS_DOM(ctx);
        rc = cifs_mount_get_session(mnt_ctx);
-       ctx->leaf_fullpath = NULL;
+       ctx->leaf_fullpath = ctx->dns_dom = NULL;
 
        return rc;
 }
        int rc = 0;
 
        if (!ctx->nodfs && ctx->dfs_automount) {
-               rc = dns_resolve_server_name_to_ip(ctx->source, addr, NULL);
+               rc = dns_resolve_server_name_to_ip(NULL, ctx->source,
+                                                  addr, NULL);
                if (!rc)
                        cifs_set_port(addr, ctx->port);
                ctx->dfs_automount = false;
 
 #include "cifsproto.h"
 #include "cifs_debug.h"
 
+static int resolve_name(const char *name, size_t namelen,
+                       struct sockaddr *addr, time64_t *expiry)
+{
+       char *ip;
+       int rc;
+
+       rc = dns_query(current->nsproxy->net_ns, NULL, name,
+                      namelen, NULL, &ip, expiry, false);
+       if (rc < 0) {
+               cifs_dbg(FYI, "%s: unable to resolve: %*.*s\n",
+                        __func__, (int)namelen, (int)namelen, name);
+       } else {
+               cifs_dbg(FYI, "%s: resolved: %*.*s to %s expiry %llu\n",
+                        __func__, (int)namelen, (int)namelen, name, ip,
+                        expiry ? (*expiry) : 0);
+
+               rc = cifs_convert_address(addr, ip, strlen(ip));
+               kfree(ip);
+               if (!rc) {
+                       cifs_dbg(FYI, "%s: unable to determine ip address\n",
+                                __func__);
+                       rc = -EHOSTUNREACH;
+               } else {
+                       rc = 0;
+               }
+       }
+       return rc;
+}
+
 /**
  * dns_resolve_server_name_to_ip - Resolve UNC server name to ip address.
+ * @dom: optional DNS domain name
  * @unc: UNC path specifying the server (with '/' as delimiter)
  * @ip_addr: Where to return the IP address.
  * @expiry: Where to return the expiry time for the dns record.
  *
  * Returns zero success, -ve on error.
  */
-int
-dns_resolve_server_name_to_ip(const char *unc, struct sockaddr *ip_addr, time64_t *expiry)
+int dns_resolve_server_name_to_ip(const char *dom, const char *unc,
+                                 struct sockaddr *ip_addr, time64_t *expiry)
 {
-       const char *hostname, *sep;
-       char *ip;
-       int len, rc;
+       const char *name;
+       size_t namelen, len;
+       char *s;
+       int rc;
 
        if (!ip_addr || !unc)
                return -EINVAL;
 
-       len = strlen(unc);
-       if (len < 3) {
-               cifs_dbg(FYI, "%s: unc is too short: %s\n", __func__, unc);
+       cifs_dbg(FYI, "%s: dom=%s unc=%s\n", __func__, dom, unc);
+       if (strlen(unc) < 3)
                return -EINVAL;
-       }
-
-       /* Discount leading slashes for cifs */
-       len -= 2;
-       hostname = unc + 2;
 
-       /* Search for server name delimiter */
-       sep = memchr(hostname, '/', len);
-       if (sep)
-               len = sep - hostname;
-       else
-               cifs_dbg(FYI, "%s: probably server name is whole unc: %s\n",
-                        __func__, unc);
+       extract_unc_hostname(unc, &name, &namelen);
+       if (!namelen)
+               return -EINVAL;
 
+       cifs_dbg(FYI, "%s: hostname=%.*s\n", __func__, (int)namelen, name);
        /* Try to interpret hostname as an IPv4 or IPv6 address */
-       rc = cifs_convert_address(ip_addr, hostname, len);
+       rc = cifs_convert_address(ip_addr, name, namelen);
        if (rc > 0) {
-               cifs_dbg(FYI, "%s: unc is IP, skipping dns upcall: %*.*s\n", __func__, len, len,
-                        hostname);
+               cifs_dbg(FYI, "%s: unc is IP, skipping dns upcall: %*.*s\n",
+                        __func__, (int)namelen, (int)namelen, name);
                return 0;
        }
 
-       /* Perform the upcall */
-       rc = dns_query(current->nsproxy->net_ns, NULL, hostname, len,
-                      NULL, &ip, expiry, false);
-       if (rc < 0) {
-               cifs_dbg(FYI, "%s: unable to resolve: %*.*s\n",
-                        __func__, len, len, hostname);
-       } else {
-               cifs_dbg(FYI, "%s: resolved: %*.*s to %s expiry %llu\n",
-                        __func__, len, len, hostname, ip,
-                        expiry ? (*expiry) : 0);
-
-               rc = cifs_convert_address(ip_addr, ip, strlen(ip));
-               kfree(ip);
+       /*
+        * If @name contains a NetBIOS name and @dom has been specified, then
+        * convert @name to an FQDN and try resolving it first.
+        */
+       if (dom && *dom && cifs_netbios_name(name, namelen)) {
+               len = strnlen(dom, CIFS_MAX_DOMAINNAME_LEN) + namelen + 2;
+               s = kmalloc(len, GFP_KERNEL);
+               if (!s)
+                       return -ENOMEM;
 
-               if (!rc) {
-                       cifs_dbg(FYI, "%s: unable to determine ip address\n", __func__);
-                       rc = -EHOSTUNREACH;
-               } else
-                       rc = 0;
+               scnprintf(s, len, "%.*s.%s", (int)namelen, name, dom);
+               rc = resolve_name(s, len - 1, ip_addr, expiry);
+               kfree(s);
+               if (!rc)
+                       return 0;
        }
-       return rc;
+       return resolve_name(name, namelen, ip_addr, expiry);
 }