libkvikio  23.12.00
cuda.hpp
1 /*
2  * Copyright (c) 2022-2023, NVIDIA CORPORATION.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #pragma once
17 
18 #include <cuda.h>
19 
20 #include <kvikio/shim/utils.hpp>
21 
22 namespace kvikio {
23 
31 class cudaAPI {
32  public:
33  decltype(cuInit)* Init{nullptr};
34  decltype(cuMemHostAlloc)* MemHostAlloc{nullptr};
35  decltype(cuMemFreeHost)* MemFreeHost{nullptr};
36  decltype(cuMemcpyHtoD)* MemcpyHtoD{nullptr};
37  decltype(cuMemcpyDtoH)* MemcpyDtoH{nullptr};
38  decltype(cuPointerGetAttribute)* PointerGetAttribute{nullptr};
39  decltype(cuPointerGetAttributes)* PointerGetAttributes{nullptr};
40  decltype(cuCtxPushCurrent)* CtxPushCurrent{nullptr};
41  decltype(cuCtxPopCurrent)* CtxPopCurrent{nullptr};
42  decltype(cuCtxGetCurrent)* CtxGetCurrent{nullptr};
43  decltype(cuMemGetAddressRange)* MemGetAddressRange{nullptr};
44  decltype(cuGetErrorName)* GetErrorName{nullptr};
45  decltype(cuGetErrorString)* GetErrorString{nullptr};
46  decltype(cuDeviceGet)* DeviceGet{nullptr};
47  decltype(cuDevicePrimaryCtxRetain)* DevicePrimaryCtxRetain{nullptr};
48  decltype(cuDevicePrimaryCtxRelease)* DevicePrimaryCtxRelease{nullptr};
49  decltype(cuStreamSynchronize)* StreamSynchronize{nullptr};
50 
51  private:
52  cudaAPI()
53  {
54  void* lib = load_library("libcuda.so.1");
55  // Notice, the API version loaded must match the version used downstream. That is,
56  // if a project uses the `_v2` CUDA Driver API or the newest Runtime API, the symbols
57  // loaded should also be the `_v2` symbols. Thus, we use KVIKIO_STRINGIFY() to get
58  // the name of the symbol through cude.h.
59  get_symbol(MemHostAlloc, lib, KVIKIO_STRINGIFY(cuMemHostAlloc));
60  get_symbol(MemFreeHost, lib, KVIKIO_STRINGIFY(cuMemFreeHost));
61  get_symbol(MemcpyHtoD, lib, KVIKIO_STRINGIFY(cuMemcpyHtoD));
62  get_symbol(MemcpyDtoH, lib, KVIKIO_STRINGIFY(cuMemcpyDtoH));
63  get_symbol(PointerGetAttribute, lib, KVIKIO_STRINGIFY(cuPointerGetAttribute));
64  get_symbol(PointerGetAttributes, lib, KVIKIO_STRINGIFY(cuPointerGetAttributes));
65  get_symbol(CtxPushCurrent, lib, KVIKIO_STRINGIFY(cuCtxPushCurrent));
66  get_symbol(CtxPopCurrent, lib, KVIKIO_STRINGIFY(cuCtxPopCurrent));
67  get_symbol(CtxGetCurrent, lib, KVIKIO_STRINGIFY(cuCtxGetCurrent));
68  get_symbol(MemGetAddressRange, lib, KVIKIO_STRINGIFY(cuMemGetAddressRange));
69  get_symbol(GetErrorName, lib, KVIKIO_STRINGIFY(cuGetErrorName));
70  get_symbol(GetErrorString, lib, KVIKIO_STRINGIFY(cuGetErrorString));
71  get_symbol(DeviceGet, lib, KVIKIO_STRINGIFY(cuDeviceGet));
72  get_symbol(DevicePrimaryCtxRetain, lib, KVIKIO_STRINGIFY(cuDevicePrimaryCtxRetain));
73  get_symbol(DevicePrimaryCtxRelease, lib, KVIKIO_STRINGIFY(cuDevicePrimaryCtxRelease));
74  get_symbol(StreamSynchronize, lib, KVIKIO_STRINGIFY(cuStreamSynchronize));
75  }
76 
77  public:
78  cudaAPI(cudaAPI const&) = delete;
79  void operator=(cudaAPI const&) = delete;
80 
81  static cudaAPI& instance()
82  {
83  static cudaAPI _instance;
84  return _instance;
85  }
86 };
87 
88 } // namespace kvikio
Shim layer of the CUDA Driver C-API.
Definition: cuda.hpp:31