Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pydantic import ValidationError
from sqlalchemy import and_, delete, func, not_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload, with_loader_criteria
from sqlalchemy.orm import joinedload

from dstack._internal import settings
from dstack._internal.core.backends.base.compute import (
Expand Down Expand Up @@ -218,9 +218,8 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
.options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends))
.options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
.options(
joinedload(InstanceModel.fleet).joinedload(FleetModel.instances),
with_loader_criteria(
InstanceModel, InstanceModel.deleted == False, include_aliases=True
joinedload(InstanceModel.fleet).joinedload(
FleetModel.instances.and_(InstanceModel.deleted == False)
),
)
.execution_options(populate_existing=True)
Expand All @@ -233,9 +232,8 @@ async def _process_instance(session: AsyncSession, instance: InstanceModel):
.options(joinedload(InstanceModel.project))
.options(joinedload(InstanceModel.jobs).load_only(JobModel.id, JobModel.status))
.options(
joinedload(InstanceModel.fleet).joinedload(FleetModel.instances),
with_loader_criteria(
InstanceModel, InstanceModel.deleted == False, include_aliases=True
joinedload(InstanceModel.fleet).joinedload(
FleetModel.instances.and_(InstanceModel.deleted == False)
),
)
.execution_options(populate_existing=True)
Expand Down
59 changes: 39 additions & 20 deletions src/dstack/_internal/server/services/fleets.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,12 @@
)
from dstack._internal.core.models.projects import Project
from dstack._internal.core.models.resources import ResourcesSpec
from dstack._internal.core.models.runs import JobProvisioningData, Requirements, get_policy_map
from dstack._internal.core.models.runs import (
JobProvisioningData,
Requirements,
RunStatus,
get_policy_map,
)
from dstack._internal.core.models.users import GlobalRole
from dstack._internal.core.services import validate_dstack_resource_name
from dstack._internal.core.services.diff import ModelDiff, copy_model, diff_models
Expand All @@ -53,6 +58,7 @@
JobModel,
MemberModel,
ProjectModel,
RunModel,
UserModel,
)
from dstack._internal.server.services import events
Expand Down Expand Up @@ -613,48 +619,61 @@ async def delete_fleets(
instance_nums: Optional[List[int]] = None,
):
res = await session.execute(
select(FleetModel)
select(FleetModel.id)
.where(
FleetModel.project_id == project.id,
FleetModel.name.in_(names),
FleetModel.deleted == False,
)
.options(joinedload(FleetModel.instances))
.order_by(FleetModel.id) # take locks in order
.with_for_update(key_share=True)
)
fleet_models = res.scalars().unique().all()
fleets_ids = sorted([f.id for f in fleet_models])
instances_ids = sorted([i.id for f in fleet_models for i in f.instances])
await session.commit()
logger.info("Deleting fleets: %s", [v.name for v in fleet_models])
fleets_ids = list(res.scalars().unique().all())
res = await session.execute(
select(InstanceModel.id)
.where(
InstanceModel.fleet_id.in_(fleets_ids),
InstanceModel.deleted == False,
)
.order_by(InstanceModel.id) # take locks in order
.with_for_update(key_share=True)
)
instances_ids = list(res.scalars().unique().all())
if is_db_sqlite():
# Start new transaction to see committed changes after lock
await session.commit()
async with (
get_locker(get_db().dialect_name).lock_ctx(FleetModel.__tablename__, fleets_ids),
get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instances_ids),
):
# Refetch after lock
# TODO: Lock instances with FOR UPDATE?
# TODO: Do not lock fleet when deleting only instances
# Refetch after lock.
# TODO: Do not lock fleet when deleting only instances.
res = await session.execute(
select(FleetModel)
.where(
FleetModel.project_id == project.id,
FleetModel.name.in_(names),
FleetModel.deleted == False,
)
.where(FleetModel.id.in_(fleets_ids))
.options(
selectinload(FleetModel.instances)
joinedload(FleetModel.instances.and_(InstanceModel.id.in_(instances_ids)))
.joinedload(InstanceModel.jobs)
.load_only(JobModel.id)
)
.options(selectinload(FleetModel.runs))
.options(
joinedload(
FleetModel.runs.and_(RunModel.status.not_in(RunStatus.finished_statuses()))
)
)
.execution_options(populate_existing=True)
.order_by(FleetModel.id) # take locks in order
.with_for_update(key_share=True)
)
fleet_models = res.scalars().unique().all()
fleets = [fleet_model_to_fleet(m) for m in fleet_models]
for fleet in fleets:
if fleet.spec.configuration.ssh_config is not None:
_check_can_manage_ssh_fleets(user=user, project=project)
if instance_nums is None:
logger.info("Deleting fleets: %s", [f.name for f in fleet_models])
else:
logger.info(
"Deleting fleets %s instances %s", [f.name for f in fleet_models], instance_nums
)
for fleet_model in fleet_models:
_terminate_fleet_instances(fleet_model=fleet_model, instance_nums=instance_nums)
# TERMINATING fleets are deleted by process_fleets after instances are terminated
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,34 @@ async def test_terminate(self, test_db, session: AsyncSession):
assert instance.deleted_at is not None
assert instance.finished_at is not None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_terminates_terminating_deleted_instance(self, test_db, session: AsyncSession):
    # Regression check: a race condition used to allow an instance to remain in
    # Terminating even though it had already been marked as deleted.
    # TODO: Drop this after all such "bad" instances are processed.
    project = await create_project(session=session)
    instance = await create_instance(
        session=session,
        project=project,
        status=InstanceStatus.TERMINATING,
    )

    # Reproduce the "bad" state: status is TERMINATING while deleted is already set.
    instance.deleted = True
    instance.termination_reason = InstanceTerminationReason.IDLE_TIMEOUT
    # One timestamp, 19 minutes in the past, shared by both fields.
    stale_timestamp = get_current_datetime() - dt.timedelta(minutes=19)
    instance.last_job_processed_at = stale_timestamp
    instance.deleted_at = stale_timestamp
    await session.commit()

    with self.mock_terminate_in_backend() as terminate_mock:
        await process_instances()
        # Processing must have invoked the (mocked) backend termination exactly once.
        terminate_mock.assert_called_once()

    # Reload the row so the assertions below see what processing persisted.
    await session.refresh(instance)

    assert instance is not None
    assert instance.status == InstanceStatus.TERMINATED
    assert instance.deleted == True
    assert instance.deleted_at is not None
    assert instance.finished_at is not None

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@pytest.mark.parametrize(
Expand Down