Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cuda_core/cuda/core/_context.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ cdef class Context:
cdef:
ContextHandle _h_context
int _device_id
object __weakref__

@staticmethod
cdef Context _from_handle(type cls, ContextHandle h_context, int device_id)
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -955,7 +955,7 @@ class Device:
Default value of `None` return the currently used device.

"""
__slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context")
__slots__ = ("_device_id", "_memory_resource", "_has_inited", "_properties", "_uuid", "_context", "__weakref__")

def __new__(cls, device_id: Device | int | None = None):
if isinstance(device_id, Device):
Expand Down
1 change: 1 addition & 0 deletions cuda_core/cuda/core/_event.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cdef class Event:
bint _ipc_enabled
object _ipc_descriptor
int _device_id
object __weakref__

@staticmethod
cdef Event _init(type cls, int device_id, ContextHandle h_context, options, bint is_free)
Expand Down
1 change: 1 addition & 0 deletions cuda_core/cuda/core/_launch_config.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cdef class LaunchConfig:
public bint cooperative_launch

vector[cydriver.CUlaunchAttribute] _attrs
object __weakref__

cdef cydriver.CUlaunchConfig _to_native_launch_config(self)

Expand Down
1 change: 1 addition & 0 deletions cuda_core/cuda/core/_memory/_buffer.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ cdef class Buffer:
object _owner
_MemAttrs _mem_attrs
bint _mem_attrs_inited
object __weakref__


cdef class MemoryResource:
Expand Down
2 changes: 1 addition & 1 deletion cuda_core/cuda/core/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ class ObjectCode:
:class:`~cuda.core.Program`
"""

__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map", "_name")
__slots__ = ("_handle", "_code_type", "_module", "_loader", "_sym_map", "_name", "__weakref__")
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library")

def __new__(self, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions cuda_core/cuda/core/_stream.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ cdef class Stream:
int _device_id
int _nonblocking
int _priority
object __weakref__

@staticmethod
cdef Stream _from_handle(type cls, StreamHandle h_stream)
Expand Down
75 changes: 75 additions & 0 deletions cuda_core/tests/test_weakref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

import weakref

import pytest
from cuda.core import Device


@pytest.fixture(scope="module")
def device():
dev = Device()
dev.set_current()
return dev


@pytest.fixture
def stream(device):
return device.create_stream()


@pytest.fixture
def event(device):
return device.create_event()


@pytest.fixture
def context(device):
return device.context


@pytest.fixture
def buffer(device):
return device.allocate(1024)


@pytest.fixture
def launch_config():
from cuda.core import LaunchConfig

return LaunchConfig(grid=(1,), block=(1,))


@pytest.fixture
def object_code():
from cuda.core import Program

prog = Program('extern "C" __global__ void test_kernel() {}', "c++")
return prog.compile("ptx")


@pytest.fixture
def kernel(object_code):
return object_code.get_kernel("test_kernel")


WEAK_REFERENCEABLE = [
"device",
"stream",
"event",
"context",
"buffer",
"launch_config",
"object_code",
"kernel",
]


@pytest.mark.parametrize("fixture_name", WEAK_REFERENCEABLE)
def test_weakref(fixture_name, request):
"""Core API classes should be weak-referenceable."""
obj = request.getfixturevalue(fixture_name)
ref = weakref.ref(obj)
assert ref() is obj
Loading