Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]

defaults:
run:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ requires-python = ">=3.8"
dependencies = [
"httpx>=0.21.0,<1",
"packaging",
"pydantic>1.10.7",
"pydantic>=2.0",
"typing_extensions>=4.5.0",
]

Expand Down
2 changes: 1 addition & 1 deletion replicate/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class Account(Resource):
name: str
"""The name of the account."""

github_url: Optional[str]
github_url: Optional[str] = None
"""The GitHub URL of the account."""


Expand Down
11 changes: 3 additions & 8 deletions replicate/deployment.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, TypedDict, Union

import pydantic
from typing_extensions import Unpack, deprecated

from replicate.account import Account
Expand All @@ -13,12 +14,6 @@
)
from replicate.resource import Namespace, Resource

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore


if TYPE_CHECKING:
from replicate.client import Client
from replicate.prediction import Predictions
Expand Down Expand Up @@ -66,7 +61,7 @@ class Release(Resource):
The time the release was created.
"""

created_by: Optional[Account]
created_by: Optional[Account] = None
"""
The account that created the release.
"""
Expand Down Expand Up @@ -96,7 +91,7 @@ class Configuration(Resource):
The deployment configuration.
"""

current_release: Optional[Release]
current_release: Optional[Release] = None
"""
The current release of the deployment.
"""
Expand Down
2 changes: 1 addition & 1 deletion replicate/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class File(Resource):
created_at: str
"""The time the file was created."""

expires_at: Optional[str]
expires_at: Optional[str] = None
"""The time the file will expire."""

urls: Dict[str, str]
Expand Down
23 changes: 9 additions & 14 deletions replicate/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Tuple, Union, overload

import pydantic
from typing_extensions import NotRequired, TypedDict, Unpack, deprecated

from replicate.exceptions import ReplicateException
Expand All @@ -15,12 +16,6 @@
from replicate.resource import Namespace, Resource
from replicate.version import Version, Versions

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore


if TYPE_CHECKING:
from replicate.client import Client
from replicate.prediction import Predictions
Expand Down Expand Up @@ -48,7 +43,7 @@ class Model(Resource):
The name of the model.
"""

description: Optional[str]
description: Optional[str] = None
"""
The description of the model.
"""
Expand All @@ -58,17 +53,17 @@ class Model(Resource):
The visibility of the model. Can be 'public' or 'private'.
"""

github_url: Optional[str]
github_url: Optional[str] = None
"""
The GitHub URL of the model.
"""

paper_url: Optional[str]
paper_url: Optional[str] = None
"""
The URL of the paper related to the model.
"""

license_url: Optional[str]
license_url: Optional[str] = None
"""
The URL of the license for the model.
"""
Expand All @@ -78,17 +73,17 @@ class Model(Resource):
The number of runs of the model.
"""

cover_image_url: Optional[str]
cover_image_url: Optional[str] = None
"""
The URL of the cover image for the model.
"""

default_example: Optional[Prediction]
default_example: Optional[Prediction] = None
"""
The default example of the model.
"""

latest_version: Optional[Version]
latest_version: Optional[Version] = None
"""
The latest version of the model.
"""
Expand Down Expand Up @@ -137,7 +132,7 @@ def reload(self) -> None:
"""

obj = self._client.models.get(f"{self.owner}/{self.name}")
for name, value in obj.dict().items():
for name, value in obj.model_dump().items():
setattr(self, name, value)


Expand Down
5 changes: 1 addition & 4 deletions replicate/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@
Union,
)

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore
import pydantic

from replicate.resource import Resource

Expand Down
32 changes: 14 additions & 18 deletions replicate/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)

import httpx
import pydantic
from typing_extensions import NotRequired, TypedDict, Unpack

from replicate.exceptions import ModelError, ReplicateError
Expand All @@ -27,11 +28,6 @@
from replicate.stream import EventSource
from replicate.version import Version

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore

if TYPE_CHECKING:
from replicate.client import Client
from replicate.deployment import Deployment
Expand All @@ -58,31 +54,31 @@ class Prediction(Resource):
status: Literal["starting", "processing", "succeeded", "failed", "canceled"]
"""The status of the prediction."""

input: Optional[Dict[str, Any]]
input: Optional[Dict[str, Any]] = None
"""The input to the prediction."""

output: Optional[Any]
output: Optional[Any] = None
"""The output of the prediction."""

logs: Optional[str]
logs: Optional[str] = None
"""The logs of the prediction."""

error: Optional[str]
error: Optional[str] = None
"""The error encountered during the prediction, if any."""

metrics: Optional[Dict[str, Any]]
metrics: Optional[Dict[str, Any]] = None
"""Metrics for the prediction."""

created_at: Optional[str]
created_at: Optional[str] = None
"""When the prediction was created."""

started_at: Optional[str]
started_at: Optional[str] = None
"""When the prediction was started."""

completed_at: Optional[str]
completed_at: Optional[str] = None
"""When the prediction was completed, if finished."""

urls: Optional[Dict[str, str]]
urls: Optional[Dict[str, str]] = None
"""
URLs associated with the prediction.

Expand Down Expand Up @@ -214,7 +210,7 @@ def cancel(self) -> None:
"""

canceled = self._client.predictions.cancel(self.id)
for name, value in canceled.dict().items():
for name, value in canceled.model_dump().items():
setattr(self, name, value)

async def async_cancel(self) -> None:
Expand All @@ -223,7 +219,7 @@ async def async_cancel(self) -> None:
"""

canceled = await self._client.predictions.async_cancel(self.id)
for name, value in canceled.dict().items():
for name, value in canceled.model_dump().items():
setattr(self, name, value)

def reload(self) -> None:
Expand All @@ -232,7 +228,7 @@ def reload(self) -> None:
"""

updated = self._client.predictions.get(self.id)
for name, value in updated.dict().items():
for name, value in updated.model_dump().items():
setattr(self, name, value)

async def async_reload(self) -> None:
Expand All @@ -241,7 +237,7 @@ async def async_reload(self) -> None:
"""

updated = await self._client.predictions.async_get(self.id)
for name, value in updated.dict().items():
for name, value in updated.model_dump().items():
setattr(self, name, value)

def output_iterator(self) -> Iterator[Any]:
Expand Down
5 changes: 1 addition & 4 deletions replicate/resource.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import abc
from typing import TYPE_CHECKING

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore
import pydantic

if TYPE_CHECKING:
from replicate.client import Client
Expand Down
9 changes: 2 additions & 7 deletions replicate/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,13 @@
)

import httpx
import pydantic
from typing_extensions import Unpack

from replicate import identifier
from replicate.exceptions import ReplicateError
from replicate.helpers import transform_output

try:
from pydantic import v1 as pydantic # type: ignore
except ImportError:
import pydantic # type: ignore


if TYPE_CHECKING:
from replicate.client import Client
from replicate.identifier import ModelVersionIdentifier
Expand All @@ -49,7 +44,7 @@ class EventType(Enum):
event: EventType
data: str
id: str
retry: Optional[int]
retry: Optional[int] = None

def __str__(self) -> str:
if self.event == ServerSentEvent.EventType.OUTPUT:
Expand Down
Loading