-
Notifications
You must be signed in to change notification settings - Fork 48
/
switchml_nccl.patch
83 lines (81 loc) · 3.21 KB
/
switchml_nccl.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
diff --git a/src/init.cc b/src/init.cc
index 2be994d..0026739 100644
--- a/src/init.cc
+++ b/src/init.cc
@@ -68,31 +68,25 @@ ncclResult_t initCollNet(ncclCollNet_t* collnet) {
ncclResult_t initNetPlugin(ncclNet_t** net, ncclCollNet_t** collnet) {
void* netPluginLib = dlopen("libnccl-net.so", RTLD_NOW | RTLD_LOCAL);
if (netPluginLib == NULL) {
- // dlopen does not guarantee to set errno, but dlerror only gives us a
- // string, so checking errno doesn't hurt to try to provide a better
- // error message
- if (errno == ENOENT) {
- INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : No plugin found (libnccl-net.so), using internal implementation");
- } else {
- INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Plugin load returned %d : %s.", errno, dlerror());
- }
+ INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Plugin load returned %d : %s.", errno, dlerror());
return ncclSuccess;
}
ncclNet_t* extNet = (ncclNet_t*) dlsym(netPluginLib, STR(NCCL_PLUGIN_SYMBOL));
if (extNet == NULL) {
INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_PLUGIN_SYMBOL) " symbol.");
- } else if (initNet(extNet) == ncclSuccess) {
+ } else if (extNet->name == NULL || initNet(extNet) == ncclSuccess) {
*net = extNet;
- // Check for CollNet
- ncclCollNet_t* extCollNet = (ncclCollNet_t*) dlsym(netPluginLib, STR(NCCL_COLLNET_PLUGIN_SYMBOL));
- if (extCollNet == NULL) {
- INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
- } else if (initCollNet(extCollNet) == ncclSuccess) {
- *collnet = extCollNet;
- }
- return ncclSuccess;
}
- if (netPluginLib != NULL) dlclose(netPluginLib);
+ // Check for CollNet
+ ncclCollNet_t* extCollNet = (ncclCollNet_t*) dlsym(netPluginLib, STR(NCCL_COLLNET_PLUGIN_SYMBOL));
+ if (extCollNet == NULL) {
+ INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to find " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
+ } else {
+ *collnet = extCollNet;
+ }
+ if (extNet == NULL && extCollNet == NULL) {
+ dlclose(netPluginLib);
+ }
return ncclSuccess;
}
@@ -101,12 +95,29 @@ ncclResult_t initNet() {
NCCLCHECK(bootstrapNetInit());
NCCLCHECK(initNetPlugin(&ncclNet, &ncclCollNet));
- if (ncclNet != NULL) return ncclSuccess;
- if (initNet(&ncclNetIb) == ncclSuccess) {
- ncclNet = &ncclNetIb;
+ if (ncclNet == NULL) {
+ if (initNet(&ncclNetIb) == ncclSuccess) {
+ ncclNet = &ncclNetIb;
+ } else {
+ NCCLCHECK(initNet(&ncclNetSocket));
+ ncclNet = &ncclNetSocket;
+ }
+ } else if (ncclNet->name == NULL){
+ *ncclNet = ncclNetIb;
+ if (initNet(ncclNet) != ncclSuccess) {
+ *ncclNet = ncclNetSocket;
+ NCCLCHECK(initNet(ncclNet));
+ }
} else {
- NCCLCHECK(initNet(&ncclNetSocket));
- ncclNet = &ncclNetSocket;
+ INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Using plugin: %s.", ncclNet->name);
+ }
+ if (ncclCollNet != NULL) {
+ if (initCollNet(ncclCollNet) != ncclSuccess) {
+ INFO(NCCL_INIT|NCCL_NET, "NET/Plugin: Failed to initialize " STR(NCCL_COLLNET_PLUGIN_SYMBOL) " symbol.");
+ ncclCollNet = NULL;
+ } else {
+ INFO(NCCL_INIT|NCCL_NET, "NET/Plugin : Using collectives plugin: %s.", ncclCollNet->name);
+ }
}
return ncclSuccess;
}