# config_manager.py (20 KB)


from __future__ import annotations

import logging
import os
import shutil
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import yaml

from ..exceptions import ConfigError, ConfigValidationError, YAMLParseError
  10. logger = logging.getLogger(__name__)
  11. @dataclass
  12. class LibraryConfig:
  13. """Configuration for a template library."""
  14. name: str
  15. library_type: str = "git"
  16. url: str | None = None
  17. directory: str | None = None
  18. branch: str = "main"
  19. path: str | None = None
  20. enabled: bool = True
  21. class ConfigManager:
  22. """Manages configuration for the CLI application."""
  23. def __init__(self, config_path: str | Path | None = None) -> None:
  24. """Initialize the configuration manager.
  25. Args:
  26. config_path: Path to the configuration file. If None, auto-detects:
  27. 1. Checks for ./config.yaml (local project config)
  28. 2. Falls back to ~/.config/boilerplates/config.yaml (global config)
  29. """
  30. if config_path is None:
  31. # Check for local config.yaml in current directory first
  32. local_config = Path.cwd() / "config.yaml"
  33. if local_config.exists() and local_config.is_file():
  34. self.config_path = local_config
  35. self.is_local = True
  36. logger.debug(f"Using local config: {local_config}")
  37. else:
  38. # Fall back to global config
  39. config_dir = Path.home() / ".config" / "boilerplates"
  40. config_dir.mkdir(parents=True, exist_ok=True)
  41. self.config_path = config_dir / "config.yaml"
  42. self.is_local = False
  43. else:
  44. self.config_path = Path(config_path)
  45. self.is_local = False
  46. # Create default config if it doesn't exist (only for global config)
  47. if not self.config_path.exists():
  48. if not self.is_local:
  49. self._create_default_config()
  50. else:
  51. raise ConfigError(f"Local config file not found: {self.config_path}")
  52. else:
  53. # Migrate existing config if needed
  54. self._migrate_config_if_needed()
  55. def _create_default_config(self) -> None:
  56. """Create a default configuration file."""
  57. default_config = {
  58. "defaults": {},
  59. "preferences": {"editor": "vim", "output_dir": None, "library_paths": []},
  60. "libraries": [
  61. {
  62. "name": "default",
  63. "type": "git",
  64. "url": "https://github.com/christianlempa/boilerplates.git",
  65. "branch": "main",
  66. "directory": "library",
  67. "enabled": True,
  68. }
  69. ],
  70. }
  71. self._write_config(default_config)
  72. logger.info(f"Created default configuration at {self.config_path}")
  73. def _migrate_config_if_needed(self) -> None:
  74. """Migrate existing config to add missing sections and library types."""
  75. try:
  76. config = self._read_config()
  77. needs_migration = False
  78. # Add libraries section if missing
  79. if "libraries" not in config:
  80. logger.info("Migrating config: adding libraries section")
  81. config["libraries"] = [
  82. {
  83. "name": "default",
  84. "type": "git",
  85. "url": "https://github.com/christianlempa/boilerplates.git",
  86. "branch": "refactor/boilerplates-v2",
  87. "directory": "library",
  88. "enabled": True,
  89. }
  90. ]
  91. needs_migration = True
  92. else:
  93. # Migrate existing libraries to add 'type' field if missing
  94. # For backward compatibility, assume all old libraries without
  95. # 'type' are git libraries
  96. libraries = config.get("libraries", [])
  97. for library in libraries:
  98. if "type" not in library:
  99. lib_name = library.get("name", "unknown")
  100. logger.info(f"Migrating library '{lib_name}': adding type: git")
  101. library["type"] = "git"
  102. needs_migration = True
  103. # Write back if migration was needed
  104. if needs_migration:
  105. self._write_config(config)
  106. logger.info("Config migration completed successfully")
  107. except Exception as e:
  108. logger.warning(f"Config migration failed: {e}")
  109. def _read_config(self) -> dict[str, Any]:
  110. """Read configuration from file.
  111. Returns:
  112. Dictionary containing the configuration.
  113. Raises:
  114. YAMLParseError: If YAML parsing fails.
  115. ConfigValidationError: If configuration structure is invalid.
  116. ConfigError: If reading fails for other reasons.
  117. """
  118. try:
  119. with self.config_path.open() as f:
  120. config = yaml.safe_load(f) or {}
  121. # Validate config structure
  122. self._validate_config_structure(config)
  123. return config
  124. except yaml.YAMLError as e:
  125. logger.error(f"Failed to parse YAML configuration: {e}")
  126. raise YAMLParseError(str(self.config_path), e) from e
  127. except ConfigValidationError:
  128. # Re-raise validation errors as-is
  129. raise
  130. except OSError as e:
  131. logger.error(f"Failed to read configuration file: {e}")
  132. raise ConfigError(f"Failed to read configuration file '{self.config_path}': {e}") from e
  133. def _write_config(self, config: dict[str, Any]) -> None:
  134. """Write configuration to file atomically using temp file + rename pattern.
  135. This prevents config file corruption if write operation fails partway through.
  136. Args:
  137. config: Dictionary containing the configuration to write.
  138. Raises:
  139. ConfigValidationError: If configuration structure is invalid.
  140. ConfigError: If writing fails for any reason.
  141. """
  142. tmp_path = None
  143. try:
  144. # Validate config structure before writing
  145. self._validate_config_structure(config)
  146. # Ensure parent directory exists
  147. self.config_path.parent.mkdir(parents=True, exist_ok=True)
  148. # Write to temporary file in same directory for atomic rename
  149. with tempfile.NamedTemporaryFile(
  150. mode="w",
  151. delete=False,
  152. dir=self.config_path.parent,
  153. prefix=".config_",
  154. suffix=".tmp",
  155. ) as tmp_file:
  156. yaml.dump(config, tmp_file, default_flow_style=False)
  157. tmp_path = tmp_file.name
  158. # Atomic rename (overwrites existing file on POSIX systems)
  159. shutil.move(tmp_path, self.config_path)
  160. logger.debug(f"Configuration written atomically to {self.config_path}")
  161. except ConfigValidationError:
  162. # Re-raise validation errors as-is
  163. if tmp_path:
  164. Path(tmp_path).unlink(missing_ok=True)
  165. raise
  166. except (OSError, yaml.YAMLError) as e:
  167. # Clean up temp file if it exists
  168. if tmp_path:
  169. try:
  170. Path(tmp_path).unlink(missing_ok=True)
  171. except OSError:
  172. logger.warning(f"Failed to clean up temporary file: {tmp_path}")
  173. logger.error(f"Failed to write configuration file: {e}")
  174. raise ConfigError(f"Failed to write configuration to '{self.config_path}': {e}") from e
  175. def _validate_config_structure(self, config: dict[str, Any]) -> None:
  176. """Validate the configuration structure - basic type checking.
  177. Args:
  178. config: Configuration dictionary to validate.
  179. Raises:
  180. ConfigValidationError: If configuration structure is invalid.
  181. """
  182. if not isinstance(config, dict):
  183. raise ConfigValidationError("Configuration must be a dictionary")
  184. # Validate top-level types
  185. self._validate_top_level_types(config)
  186. # Validate defaults structure
  187. self._validate_defaults_types(config)
  188. # Validate libraries structure
  189. self._validate_libraries_fields(config)
  190. def _validate_top_level_types(self, config: dict[str, Any]) -> None:
  191. """Validate top-level config section types."""
  192. if "defaults" in config and not isinstance(config["defaults"], dict):
  193. raise ConfigValidationError("'defaults' must be a dictionary")
  194. if "preferences" in config and not isinstance(config["preferences"], dict):
  195. raise ConfigValidationError("'preferences' must be a dictionary")
  196. if "libraries" in config and not isinstance(config["libraries"], list):
  197. raise ConfigValidationError("'libraries' must be a list")
  198. def _validate_defaults_types(self, config: dict[str, Any]) -> None:
  199. """Validate defaults section has correct types."""
  200. if "defaults" not in config:
  201. return
  202. for module_name, module_defaults in config["defaults"].items():
  203. if not isinstance(module_defaults, dict):
  204. raise ConfigValidationError(f"Defaults for module '{module_name}' must be a dictionary")
  205. def _validate_libraries_fields(self, config: dict[str, Any]) -> None:
  206. """Validate libraries have required fields."""
  207. if "libraries" not in config:
  208. return
  209. for i, library in enumerate(config["libraries"]):
  210. if not isinstance(library, dict):
  211. raise ConfigValidationError(f"Library at index {i} must be a dictionary")
  212. if "name" not in library:
  213. raise ConfigValidationError(f"Library at index {i} missing required field 'name'")
  214. lib_type = library.get("type", "git")
  215. if lib_type == "git" and ("url" not in library or "directory" not in library):
  216. raise ConfigValidationError(
  217. f"Git library at index {i} missing required fields 'url' and/or 'directory'"
  218. )
  219. if lib_type == "static" and "path" not in library:
  220. raise ConfigValidationError(f"Static library at index {i} missing required field 'path'")
  221. def get_config_path(self) -> Path:
  222. """Get the path to the configuration file being used.
  223. Returns:
  224. Path to the configuration file (global or local).
  225. """
  226. return self.config_path
  227. def is_using_local_config(self) -> bool:
  228. """Check if a local configuration file is being used.
  229. Returns:
  230. True if using local config, False if using global config.
  231. """
  232. return self.is_local
  233. def get_defaults(self, module_name: str) -> dict[str, Any]:
  234. """Get default variable values for a module.
  235. Returns defaults in a flat format:
  236. {
  237. "var_name": "value",
  238. "var2_name": "value2"
  239. }
  240. Args:
  241. module_name: Name of the module
  242. Returns:
  243. Dictionary of default values (flat key-value pairs)
  244. """
  245. config = self._read_config()
  246. defaults = config.get("defaults", {})
  247. return defaults.get(module_name, {})
  248. def set_defaults(self, module_name: str, defaults: dict[str, Any]) -> None:
  249. """Set default variable values for a module with comprehensive validation.
  250. Args:
  251. module_name: Name of the module
  252. defaults: Dictionary of defaults (flat key-value pairs):
  253. {"var_name": "value", "var2_name": "value2"}
  254. Raises:
  255. ConfigValidationError: If module name or variable names are invalid.
  256. """
  257. # Basic validation
  258. if not isinstance(module_name, str) or not module_name:
  259. raise ConfigValidationError("Module name must be a non-empty string")
  260. if not isinstance(defaults, dict):
  261. raise ConfigValidationError("Defaults must be a dictionary")
  262. config = self._read_config()
  263. if "defaults" not in config:
  264. config["defaults"] = {}
  265. config["defaults"][module_name] = defaults
  266. self._write_config(config)
  267. logger.info(f"Updated defaults for module '{module_name}'")
  268. def set_default_value(self, module_name: str, var_name: str, value: Any) -> None:
  269. """Set a single default variable value with comprehensive validation.
  270. Args:
  271. module_name: Name of the module
  272. var_name: Name of the variable
  273. value: Default value to set
  274. Raises:
  275. ConfigValidationError: If module name or variable name is invalid.
  276. """
  277. # Basic validation
  278. if not isinstance(module_name, str) or not module_name:
  279. raise ConfigValidationError("Module name must be a non-empty string")
  280. if not isinstance(var_name, str) or not var_name:
  281. raise ConfigValidationError("Variable name must be a non-empty string")
  282. defaults = self.get_defaults(module_name)
  283. defaults[var_name] = value
  284. self.set_defaults(module_name, defaults)
  285. logger.info(f"Set default for '{module_name}.{var_name}' = '{value}'")
  286. def get_default_value(self, module_name: str, var_name: str) -> Any | None:
  287. """Get a single default variable value.
  288. Args:
  289. module_name: Name of the module
  290. var_name: Name of the variable
  291. Returns:
  292. Default value or None if not set
  293. """
  294. defaults = self.get_defaults(module_name)
  295. return defaults.get(var_name)
  296. def clear_defaults(self, module_name: str) -> None:
  297. """Clear all defaults for a module.
  298. Args:
  299. module_name: Name of the module
  300. """
  301. config = self._read_config()
  302. if "defaults" in config and module_name in config["defaults"]:
  303. del config["defaults"][module_name]
  304. self._write_config(config)
  305. logger.info(f"Cleared defaults for module '{module_name}'")
  306. def get_preference(self, key: str) -> Any | None:
  307. """Get a user preference value.
  308. Args:
  309. key: Preference key (e.g., 'editor', 'output_dir', 'library_paths')
  310. Returns:
  311. Preference value or None if not set
  312. """
  313. config = self._read_config()
  314. preferences = config.get("preferences", {})
  315. return preferences.get(key)
  316. def set_preference(self, key: str, value: Any) -> None:
  317. """Set a user preference value with comprehensive validation.
  318. Args:
  319. key: Preference key
  320. value: Preference value
  321. Raises:
  322. ConfigValidationError: If key or value is invalid for known preference types.
  323. """
  324. # Basic validation
  325. if not isinstance(key, str) or not key:
  326. raise ConfigValidationError("Preference key must be a non-empty string")
  327. config = self._read_config()
  328. if "preferences" not in config:
  329. config["preferences"] = {}
  330. config["preferences"][key] = value
  331. self._write_config(config)
  332. logger.info(f"Set preference '{key}' = '{value}'")
  333. def get_all_preferences(self) -> dict[str, Any]:
  334. """Get all user preferences.
  335. Returns:
  336. Dictionary of all preferences
  337. """
  338. config = self._read_config()
  339. return config.get("preferences", {})
  340. def get_libraries(self) -> list[dict[str, Any]]:
  341. """Get all configured libraries.
  342. Returns:
  343. List of library configurations
  344. """
  345. config = self._read_config()
  346. return config.get("libraries", [])
  347. def get_library_by_name(self, name: str) -> dict[str, Any] | None:
  348. """Get a specific library by name.
  349. Args:
  350. name: Name of the library
  351. Returns:
  352. Library configuration dictionary or None if not found
  353. """
  354. libraries = self.get_libraries()
  355. for library in libraries:
  356. if library.get("name") == name:
  357. return library
  358. return None
  359. def add_library(self, lib_config: LibraryConfig) -> None:
  360. """Add a new library to the configuration.
  361. Args:
  362. lib_config: Library configuration
  363. Raises:
  364. ConfigValidationError: If library with the same name already exists or validation fails
  365. """
  366. # Basic validation
  367. if not isinstance(lib_config.name, str) or not lib_config.name:
  368. raise ConfigValidationError("Library name must be a non-empty string")
  369. if lib_config.library_type not in ("git", "static"):
  370. raise ConfigValidationError(f"Library type must be 'git' or 'static', got '{lib_config.library_type}'")
  371. if self.get_library_by_name(lib_config.name):
  372. raise ConfigValidationError(f"Library '{lib_config.name}' already exists")
  373. # Type-specific validation
  374. if lib_config.library_type == "git":
  375. if not lib_config.url or not lib_config.directory:
  376. raise ConfigValidationError("Git libraries require 'url' and 'directory' parameters")
  377. library_dict = {
  378. "name": lib_config.name,
  379. "type": "git",
  380. "url": lib_config.url,
  381. "branch": lib_config.branch,
  382. "directory": lib_config.directory,
  383. "enabled": lib_config.enabled,
  384. }
  385. else: # static
  386. if not lib_config.path:
  387. raise ConfigValidationError("Static libraries require 'path' parameter")
  388. # For backward compatibility with older CLI versions,
  389. # add dummy values for git-specific fields
  390. library_dict = {
  391. "name": lib_config.name,
  392. "type": "static",
  393. "url": "", # Empty string for backward compatibility
  394. "branch": "main", # Default value for backward compatibility
  395. "directory": ".", # Default value for backward compatibility
  396. "path": lib_config.path,
  397. "enabled": lib_config.enabled,
  398. }
  399. config = self._read_config()
  400. if "libraries" not in config:
  401. config["libraries"] = []
  402. config["libraries"].append(library_dict)
  403. self._write_config(config)
  404. logger.info(f"Added {lib_config.library_type} library '{lib_config.name}'")
  405. def remove_library(self, name: str) -> None:
  406. """Remove a library from the configuration.
  407. Args:
  408. name: Name of the library to remove
  409. Raises:
  410. ConfigError: If library is not found
  411. """
  412. config = self._read_config()
  413. libraries = config.get("libraries", [])
  414. # Find and remove the library
  415. new_libraries = [lib for lib in libraries if lib.get("name") != name]
  416. if len(new_libraries) == len(libraries):
  417. raise ConfigError(f"Library '{name}' not found")
  418. config["libraries"] = new_libraries
  419. self._write_config(config)
  420. logger.info(f"Removed library '{name}'")
  421. def update_library(self, name: str, **kwargs: Any) -> None:
  422. """Update a library's configuration.
  423. Args:
  424. name: Name of the library to update
  425. **kwargs: Fields to update (url, branch, directory, enabled)
  426. Raises:
  427. ConfigError: If library is not found
  428. ConfigValidationError: If validation fails
  429. """
  430. config = self._read_config()
  431. libraries = config.get("libraries", [])
  432. # Find the library
  433. library_found = False
  434. for library in libraries:
  435. if library.get("name") == name:
  436. library_found = True
  437. # Update allowed fields
  438. if "url" in kwargs:
  439. library["url"] = kwargs["url"]
  440. if "branch" in kwargs:
  441. library["branch"] = kwargs["branch"]
  442. if "directory" in kwargs:
  443. library["directory"] = kwargs["directory"]
  444. if "enabled" in kwargs:
  445. library["enabled"] = kwargs["enabled"]
  446. break
  447. if not library_found:
  448. raise ConfigError(f"Library '{name}' not found")
  449. config["libraries"] = libraries
  450. self._write_config(config)
  451. logger.info(f"Updated library '{name}'")
  452. def get_libraries_path(self) -> Path:
  453. """Get the path to the libraries directory.
  454. Returns:
  455. Path to the libraries directory (same directory as config file)
  456. """
  457. return self.config_path.parent / "libraries"