diff --git a/commitizen/commands/bump.py b/commitizen/commands/bump.py index f6637b5c7..e4b94c482 100644 --- a/commitizen/commands/bump.py +++ b/commitizen/commands/bump.py @@ -2,13 +2,14 @@ import warnings from logging import getLogger -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import questionary from commitizen import bump, factory, git, hooks, out from commitizen.changelog_formats import get_changelog_format from commitizen.commands.changelog import Changelog +from commitizen.config.settings import ChainSettings from commitizen.defaults import Settings from commitizen.exceptions import ( BumpCommitFailedError, @@ -70,37 +71,11 @@ def __init__(self, config: BaseConfig, arguments: BumpArgs) -> None: self.config: BaseConfig = config self.arguments = arguments - self.bump_settings = cast( - "BumpArgs", - { - **config.settings, - **{ - k: v - for k in ( - "annotated_tag_message", - "annotated_tag", - "bump_message", - "file_name", - "gpg_sign", - "increment_mode", - "increment", - "major_version_zero", - "prerelease_offset", - "prerelease", - "tag_format", - "template", - ) - if (v := arguments.get(k)) is not None - }, - }, - ) + self.settings = ChainSettings(config.settings, arguments).load_settings() self.cz = factory.committer_factory(self.config) self.changelog_flag = arguments["changelog"] self.changelog_to_stdout = arguments["changelog_to_stdout"] self.git_output_to_stderr = arguments["git_output_to_stderr"] - self.no_verify = arguments["no_verify"] - self.check_consistency = arguments["check_consistency"] - self.retry = arguments["retry"] self.pre_bump_hooks = self.config.settings["pre_bump_hooks"] self.post_bump_hooks = self.config.settings["post_bump_hooks"] deprecated_version_type = arguments.get("version_type") @@ -148,7 +123,7 @@ def _find_increment(self, commits: list[git.GitCommit]) -> Increment | None: # self.cz.bump_map = defaults.bump_map_major_version_zero bump_map = ( self.cz.bump_map_major_version_zero - if self.bump_settings["major_version_zero"] + if self.settings["major_version_zero"] else self.cz.bump_map ) bump_pattern = self.cz.bump_pattern @@ -230,7 +205,7 @@ def _resolve_increment_and_new_version( return increment, current_version.bump( increment, prerelease=self.arguments["prerelease"], - prerelease_offset=self.bump_settings["prerelease_offset"], + prerelease_offset=self.settings["prerelease_offset"], devrelease=self.arguments["devrelease"], is_local_version=self.arguments["local_version"], build_metadata=self.arguments["build_metadata"], @@ -262,7 +237,7 @@ def __call__(self) -> None: ) ) - rules = TagRules.from_settings(cast("Settings", self.bump_settings)) + rules = TagRules.from_settings(self.settings) current_tag = rules.find_tag_for(git.get_tags(), current_version) current_tag_version = ( current_tag.name if current_tag else rules.normalize_tag(current_version) @@ -285,7 +260,7 @@ def __call__(self) -> None: raise DryRunExit() message = bump.create_commit_message( - current_version, new_version, self.bump_settings["bump_message"] + current_version, new_version, self.settings["bump_message"] ) # Report found information information = f"{message}\ntag to create: {new_tag_version}\n" @@ -342,8 +317,8 @@ def __call__(self) -> None: bump.update_version_in_files( str(current_version), str(new_version), - self.bump_settings["version_files"], - check_consistency=self.check_consistency, + self.settings["version_files"], + check_consistency=self.arguments["check_consistency"], encoding=self.config.settings["encoding"], ) ) @@ -372,7 +347,7 @@ def __call__(self) -> None: # FIXME: check if any changes have been staged git.add(*updated_files) c = git.commit(message, args=self._get_commit_args()) - if self.retry and c.return_code != 0 and self.changelog_flag: + if self.arguments["retry"] and c.return_code != 0 and self.changelog_flag: # Maybe pre-commit reformatted some files? Retry once logger.debug("1st git.commit error: %s", c.err) logger.info("1st commit attempt failed; retrying once") @@ -391,18 +366,18 @@ def __call__(self) -> None: new_tag_version, signed=any( ( - self.bump_settings.get("gpg_sign"), - self.config.settings.get("gpg_sign"), + self.settings.get("gpg_sign"), + self.config.settings.get("gpg_sign"), # TODO: remove this ) ), annotated=any( ( - self.bump_settings.get("annotated_tag"), - self.config.settings.get("annotated_tag"), - self.bump_settings.get("annotated_tag_message"), + self.settings.get("annotated_tag"), + self.config.settings.get("annotated_tag"), # TODO: remove this + self.settings.get("annotated_tag_message"), ) ), - msg=self.bump_settings.get("annotated_tag_message", None), + msg=self.settings.get("annotated_tag_message", None), # type: ignore[arg-type] # TODO: also get from self.config.settings? ) if c.return_code != 0: @@ -432,6 +407,6 @@ def __call__(self) -> None: def _get_commit_args(self) -> str: commit_args = ["-a"] - if self.no_verify: + if self.arguments["no_verify"]: commit_args.append("--no-verify") return " ".join(commit_args) diff --git a/commitizen/commands/check.py b/commitizen/commands/check.py index 8ec5b47f8..d543f5171 100644 --- a/commitizen/commands/check.py +++ b/commitizen/commands/check.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, TypedDict from commitizen import factory, git, out +from commitizen.config.settings import ChainSettings from commitizen.exceptions import ( InvalidCommandArgumentError, InvalidCommitMessageError, @@ -20,7 +21,7 @@ class CheckArgs(TypedDict, total=False): commit_msg: str rev_range: str allow_abort: bool - message_length_limit: int | None + message_length_limit: int allowed_prefixes: list[str] message: str use_default_range: bool @@ -37,25 +38,12 @@ def __init__(self, config: BaseConfig, arguments: CheckArgs, *args: object) -> N arguments: All the flags provided by the user cwd: Current work directory """ + self.settings = ChainSettings(config.settings, arguments).load_settings() self.commit_msg_file = arguments.get("commit_msg_file") self.commit_msg = arguments.get("message") self.rev_range = arguments.get("rev_range") - self.allow_abort = bool( - arguments.get("allow_abort", config.settings["allow_abort"]) - ) self.use_default_range = bool(arguments.get("use_default_range")) - self.max_msg_length = arguments.get( - "message_length_limit", config.settings.get("message_length_limit", None) - ) - - # we need to distinguish between None and [], which is a valid value - allowed_prefixes = arguments.get("allowed_prefixes") - self.allowed_prefixes: list[str] = ( - allowed_prefixes - if allowed_prefixes is not None - else config.settings["allowed_prefixes"] - ) num_exclusive_args_provided = sum( arg is not None @@ -97,9 +85,9 @@ def __call__(self) -> None: check := self.cz.validate_commit_message( commit_msg=commit.message, pattern=pattern, - allow_abort=self.allow_abort, - allowed_prefixes=self.allowed_prefixes, - max_msg_length=self.max_msg_length, + allow_abort=self.settings["allow_abort"], + allowed_prefixes=self.settings["allowed_prefixes"], + max_msg_length=self.settings["message_length_limit"] or 0, commit_hash=commit.rev, ) ).is_valid diff --git a/commitizen/commands/commit.py b/commitizen/commands/commit.py index 3894d0b77..6df943735 100644 --- a/commitizen/commands/commit.py +++ b/commitizen/commands/commit.py @@ -10,6 +10,7 @@ import questionary from commitizen import factory, git, out +from commitizen.config.settings import ChainSettings from commitizen.cz.exceptions import CzException from commitizen.cz.utils import get_backup_file_path from commitizen.exceptions import ( @@ -36,7 +37,7 @@ class CommitArgs(TypedDict, total=False): dry_run: bool edit: bool extra_cli_args: str - message_length_limit: int | None + message_length_limit: int no_retry: bool signoff: bool write_message_to_file: Path | None @@ -53,6 +54,7 @@ def __init__(self, config: BaseConfig, arguments: CommitArgs) -> None: self.config: BaseConfig = config self.cz = factory.committer_factory(self.config) self.arguments = arguments + self.settings = ChainSettings(config.settings, arguments).load_settings() self.backup_file_path = get_backup_file_path() def _read_backup_message(self) -> str | None: @@ -61,16 +63,14 @@ def _read_backup_message(self) -> str | None: return None # Read commit message from backup - with open( - self.backup_file_path, encoding=self.config.settings["encoding"] - ) as f: + with open(self.backup_file_path, encoding=self.settings["encoding"]) as f: return f.read().strip() def _get_message_by_prompt_commit_questions(self) -> str: # Prompt user for the commit message questions = self.cz.questions() for question in (q for q in questions if q["type"] == "list"): - question["use_shortcuts"] = self.config.settings["use_shortcuts"] + question["use_shortcuts"] = self.settings["use_shortcuts"] try: answers = questionary.prompt(questions, style=self.cz.style) except ValueError as err: @@ -83,21 +83,16 @@ def _get_message_by_prompt_commit_questions(self) -> str: raise NoAnswersError() message = self.cz.message(answers) - if limit := self.arguments.get( - "message_length_limit", self.config.settings.get("message_length_limit", 0) - ): - self._validate_subject_length(message=message, length_limit=limit) + if (length_limit := self.settings["message_length_limit"]) > 0: + # By the contract, message_length_limit is set to 0 for no limit + subject = message.partition("\n")[0].strip() + if len(subject) > length_limit: + raise CommitMessageLengthExceededError( + f"Length of commit message exceeds limit ({len(subject)}/{length_limit}), subject: '{subject}'" + ) return message - def _validate_subject_length(self, *, message: str, length_limit: int) -> None: - # By the contract, message_length_limit is set to 0 for no limit - subject = message.partition("\n")[0].strip() - if len(subject) > length_limit: - raise CommitMessageLengthExceededError( - f"Length of commit message exceeds limit ({len(subject)}/{length_limit}), subject: '{subject}'" - ) - def manual_edit(self, message: str) -> str: editor = git.get_core_editor() if editor is None: @@ -123,7 +118,7 @@ def _get_message(self) -> str: return commit_message if ( - self.config.settings.get("retry_after_failure") + self.settings.get("retry_after_failure") and not self.arguments.get("no_retry") and (backup_message := self._read_backup_message()) ): @@ -158,14 +153,14 @@ def __call__(self) -> None: if write_message_to_file: with smart_open( - write_message_to_file, "w", encoding=self.config.settings["encoding"] + write_message_to_file, "w", encoding=self.settings["encoding"] ) as file: file.write(commit_message) if dry_run: raise DryRunExit() - if self.config.settings["always_signoff"] or signoff: + if self.settings["always_signoff"] or signoff: extra_args = f"{extra_args} -s".strip() c = git.commit(commit_message, args=extra_args) @@ -174,7 +169,7 @@ def __call__(self) -> None: # Create commit backup with smart_open( - self.backup_file_path, "w", encoding=self.config.settings["encoding"] + self.backup_file_path, "w", encoding=self.settings["encoding"] ) as f: f.write(commit_message) diff --git a/commitizen/config/settings.py b/commitizen/config/settings.py new file mode 100644 index 000000000..719ef045d --- /dev/null +++ b/commitizen/config/settings.py @@ -0,0 +1,27 @@ +from collections import ChainMap +from collections.abc import Mapping +from typing import Any, cast + +from commitizen.defaults import DEFAULT_SETTINGS, Settings + + +class ChainSettings: + def __init__( + self, + config_file_settings: Mapping[str, Any], + cli_settings: Mapping[str, Any] | None = None, + ) -> None: + if cli_settings is None: + cli_settings = {} + self._chain_map: ChainMap[str, Any] = ChainMap[Any, Any]( + self._remove_none_values(cli_settings), + self._remove_none_values(config_file_settings), + DEFAULT_SETTINGS, # type: ignore[arg-type] + ) + + def load_settings(self) -> Settings: + return cast("Settings", dict(self._chain_map)) + + def _remove_none_values(self, settings: Mapping[str, Any]) -> dict[str, Any]: + """HACK: remove None values from settings to avoid incorrectly overriding settings.""" + return {k: v for k, v in settings.items() if v is not None} diff --git a/commitizen/cz/base.py b/commitizen/cz/base.py index 90633c42e..6c98a52b9 100644 --- a/commitizen/cz/base.py +++ b/commitizen/cz/base.py @@ -118,7 +118,7 @@ def validate_commit_message( pattern: re.Pattern[str], allow_abort: bool, allowed_prefixes: list[str], - max_msg_length: int | None, + max_msg_length: int, commit_hash: str, ) -> ValidationResult: """Validate commit message against the pattern.""" @@ -130,7 +130,7 @@ def validate_commit_message( if any(map(commit_msg.startswith, allowed_prefixes)): return ValidationResult(True, []) - if max_msg_length is not None: + if max_msg_length > 0: msg_len = len(commit_msg.partition("\n")[0].strip()) if msg_len > max_msg_length: # TODO: capitalize the first letter of the error message for consistency in v5 diff --git a/commitizen/defaults.py b/commitizen/defaults.py index 6de41f63d..4865ccc18 100644 --- a/commitizen/defaults.py +++ b/commitizen/defaults.py @@ -48,7 +48,7 @@ class Settings(TypedDict, total=False): ignored_tag_formats: Sequence[str] legacy_tag_formats: Sequence[str] major_version_zero: bool - message_length_limit: int | None + message_length_limit: int name: str post_bump_hooks: list[str] | None pre_bump_hooks: list[str] | None @@ -114,7 +114,7 @@ class Settings(TypedDict, total=False): "template": None, # default provided by plugin "extras": {}, "breaking_change_exclamation_in_title": False, - "message_length_limit": None, # None for no limit + "message_length_limit": 0, # 0 for no limit } MAJOR = "MAJOR" diff --git a/tests/commands/test_check_command.py b/tests/commands/test_check_command.py index b5e3fd2b0..31e03a5c3 100644 --- a/tests/commands/test_check_command.py +++ b/tests/commands/test_check_command.py @@ -385,7 +385,7 @@ def test_check_command_cli_overrides_config_message_length_limit( ): message = "fix(scope): some commit message" config.settings["message_length_limit"] = len(message) - 1 - for message_length_limit in [len(message) + 1, None]: + for message_length_limit in [len(message) + 1, 0]: success_mock.reset_mock() commands.Check( config=config, diff --git a/tests/commands/test_commit_command.py b/tests/commands/test_commit_command.py index 87c7aca42..89a9224b8 100644 --- a/tests/commands/test_commit_command.py +++ b/tests/commands/test_commit_command.py @@ -363,5 +363,5 @@ def test_commit_command_with_config_message_length_limit( success_mock.assert_called_once() success_mock.reset_mock() - commands.Commit(config, {"message_length_limit": None})() + commands.Commit(config, {"message_length_limit": 0})() success_mock.assert_called_once() diff --git a/tests/test_conf.py b/tests/test_conf.py index 6e4256f16..ee5eba5b3 100644 --- a/tests/test_conf.py +++ b/tests/test_conf.py @@ -106,7 +106,7 @@ "template": None, "extras": {}, "breaking_change_exclamation_in_title": False, - "message_length_limit": None, + "message_length_limit": 0, } _new_settings: dict[str, Any] = { @@ -146,7 +146,7 @@ "template": None, "extras": {}, "breaking_change_exclamation_in_title": False, - "message_length_limit": None, + "message_length_limit": 0, } diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 000000000..1be3c9616 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,96 @@ +import pytest + +from commitizen.config.settings import ChainSettings +from commitizen.defaults import DEFAULT_SETTINGS, Settings + + +def test_cz_settings_with_empty_config_file_settings() -> None: + """Test that empty config file settings returns default settings.""" + cz_settings = ChainSettings({}) + assert cz_settings.load_settings() == DEFAULT_SETTINGS + + +@pytest.mark.parametrize( + "config_file_settings,expected_name,expected_changelog", + [ + pytest.param( + {"name": "custom_cz", "changelog_file": "CUSTOM_CHANGELOG.md"}, + "custom_cz", + "CUSTOM_CHANGELOG.md", + id="multiple_config_settings", + ), + pytest.param( + {"name": "test_cz"}, + "test_cz", + DEFAULT_SETTINGS["changelog_file"], + id="partial_config_settings", + ), + ], +) +def test_cz_settings_merges_config_file_settings( + config_file_settings: Settings, expected_name: str, expected_changelog: str +) -> None: + """Test that config file settings override default settings.""" + cz_settings = ChainSettings(config_file_settings) + result = cz_settings.load_settings() + + assert result["name"] == expected_name + assert result["changelog_file"] == expected_changelog + assert result["tag_format"] == DEFAULT_SETTINGS["tag_format"] + + +@pytest.mark.parametrize( + "config_file_settings,cli_settings,expected_name,expected_changelog,expected_tag_format", + [ + pytest.param( + {"name": "config_cz", "changelog_file": "CONFIG_CHANGELOG.md"}, + {"name": "cli_cz"}, + "cli_cz", + "CONFIG_CHANGELOG.md", + DEFAULT_SETTINGS["tag_format"], + id="cli_overrides_config_file", + ), + pytest.param( + {}, + {"name": "cli_cz", "changelog_file": "CLI_CHANGELOG.md"}, + "cli_cz", + "CLI_CHANGELOG.md", + DEFAULT_SETTINGS["tag_format"], + id="cli_overrides_defaults", + ), + pytest.param( + { + "name": "config_cz", + "changelog_file": "CONFIG_CHANGELOG.md", + "tag_format": "v$version", + }, + {"name": "cli_cz", "changelog_file": "CLI_CHANGELOG.md"}, + "cli_cz", + "CLI_CHANGELOG.md", + "v$version", + id="cli_overrides_multiple_config_settings", + ), + pytest.param( + {"name": "config_cz", "changelog_file": "CONFIG_CHANGELOG.md"}, + {}, + "config_cz", + "CONFIG_CHANGELOG.md", + DEFAULT_SETTINGS["tag_format"], + id="empty_cli_with_config_file", + ), + ], +) +def test_cz_settings_cli_precedence( + config_file_settings: Settings, + cli_settings: Settings, + expected_name: str, + expected_changelog: str, + expected_tag_format: str, +) -> None: + """Test that CLI settings take precedence over config file and defaults.""" + cz_settings = ChainSettings(config_file_settings, cli_settings) + result = cz_settings.load_settings() + + assert result["name"] == expected_name + assert result["changelog_file"] == expected_changelog + assert result["tag_format"] == expected_tag_format