collection.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. from __future__ import annotations
  2. from collections import defaultdict
  3. from typing import Any, Dict, List, Optional, Set, Union
  4. import logging
  5. from .variable import Variable
  6. from .section import VariableSection
  7. logger = logging.getLogger(__name__)
  8. class VariableCollection:
  9. """Manages variables grouped by sections and builds Jinja context."""
  10. def __init__(self, spec: dict[str, Any]) -> None:
  11. """Initialize VariableCollection from a specification dictionary.
  12. Args:
  13. spec: Dictionary containing the complete variable specification structure
  14. Expected format (as used in compose.py):
  15. {
  16. "section_key": {
  17. "title": "Section Title",
  18. "prompt": "Optional prompt text",
  19. "toggle": "optional_toggle_var_name",
  20. "description": "Optional description",
  21. "vars": {
  22. "var_name": {
  23. "description": "Variable description",
  24. "type": "str",
  25. "default": "default_value",
  26. ...
  27. }
  28. }
  29. }
  30. }
  31. """
  32. if not isinstance(spec, dict):
  33. raise ValueError("Spec must be a dictionary")
  34. self._sections: Dict[str, VariableSection] = {}
  35. # NOTE: The _variable_map provides a flat, O(1) lookup for any variable by its name,
  36. # avoiding the need to iterate through sections. It stores references to the same
  37. # Variable objects contained in the _set structure.
  38. self._variable_map: Dict[str, Variable] = {}
  39. self._initialize_sections(spec)
  40. # Validate dependencies after all sections are loaded
  41. self._validate_dependencies()
  42. def _initialize_sections(self, spec: dict[str, Any]) -> None:
  43. """Initialize sections from the spec."""
  44. for section_key, section_data in spec.items():
  45. if not isinstance(section_data, dict):
  46. continue
  47. section = self._create_section(section_key, section_data)
  48. # Guard against None from empty YAML sections (vars: with no content)
  49. vars_data = section_data.get("vars") or {}
  50. self._initialize_variables(section, vars_data)
  51. self._sections[section_key] = section
  52. # Validate all variable names are unique across sections
  53. self._validate_unique_variable_names()
  54. def _create_section(self, key: str, data: dict[str, Any]) -> VariableSection:
  55. """Create a VariableSection from data."""
  56. section_init_data = {
  57. "key": key,
  58. "title": data.get("title", key.replace("_", " ").title()),
  59. "description": data.get("description"),
  60. "toggle": data.get("toggle"),
  61. "required": data.get("required", key == "general"),
  62. "needs": data.get("needs")
  63. }
  64. return VariableSection(section_init_data)
  65. def _initialize_variables(self, section: VariableSection, vars_data: dict[str, Any]) -> None:
  66. """Initialize variables for a section."""
  67. # Guard against None from empty YAML sections
  68. if vars_data is None:
  69. vars_data = {}
  70. for var_name, var_data in vars_data.items():
  71. var_init_data = {"name": var_name, **var_data}
  72. variable = Variable(var_init_data)
  73. section.variables[var_name] = variable
  74. # NOTE: Populate the direct lookup map for efficient access.
  75. self._variable_map[var_name] = variable
  76. # Validate toggle variable after all variables are added
  77. self._validate_section_toggle(section)
  78. # TODO: Add more section-level validation:
  79. # - Validate that required sections have at least one non-toggle variable
  80. # - Validate that enum variables have non-empty options lists
  81. # - Validate that variable names follow naming conventions (e.g., lowercase_with_underscores)
  82. # - Validate that default values are compatible with their type definitions
  83. def _validate_unique_variable_names(self) -> None:
  84. """Validate that all variable names are unique across all sections."""
  85. var_to_sections: Dict[str, List[str]] = defaultdict(list)
  86. # Build mapping of variable names to sections
  87. for section_key, section in self._sections.items():
  88. for var_name in section.variables:
  89. var_to_sections[var_name].append(section_key)
  90. # Find duplicates and format error
  91. duplicates = {var: sections for var, sections in var_to_sections.items() if len(sections) > 1}
  92. if duplicates:
  93. errors = ["Variable names must be unique across all sections, but found duplicates:"]
  94. errors.extend(f" - '{var}' appears in sections: {', '.join(secs)}" for var, secs in sorted(duplicates.items()))
  95. errors.append("\nPlease rename variables to be unique or consolidate them into a single section.")
  96. error_msg = "\n".join(errors)
  97. logger.error(error_msg)
  98. raise ValueError(error_msg)
  99. def _validate_section_toggle(self, section: VariableSection) -> None:
  100. """Validate that toggle variable is of type bool if it exists.
  101. If the toggle variable doesn't exist (e.g., filtered out), removes the toggle.
  102. Args:
  103. section: The section to validate
  104. Raises:
  105. ValueError: If toggle variable exists but is not boolean type
  106. """
  107. if not section.toggle:
  108. return
  109. toggle_var = section.variables.get(section.toggle)
  110. if not toggle_var:
  111. # Toggle variable doesn't exist (e.g., was filtered out) - remove toggle metadata
  112. section.toggle = None
  113. return
  114. if toggle_var.type != "bool":
  115. raise ValueError(
  116. f"Section '{section.key}' toggle variable '{section.toggle}' must be type 'bool', "
  117. f"but is type '{toggle_var.type}'"
  118. )
  119. def _validate_dependencies(self) -> None:
  120. """Validate section dependencies for cycles and missing references.
  121. Raises:
  122. ValueError: If circular dependencies or missing section references are found
  123. """
  124. # Check for missing dependencies
  125. for section_key, section in self._sections.items():
  126. for dep in section.needs:
  127. if dep not in self._sections:
  128. raise ValueError(
  129. f"Section '{section_key}' depends on '{dep}', but '{dep}' does not exist"
  130. )
  131. # Check for circular dependencies using depth-first search
  132. visited = set()
  133. rec_stack = set()
  134. def has_cycle(section_key: str) -> bool:
  135. visited.add(section_key)
  136. rec_stack.add(section_key)
  137. section = self._sections[section_key]
  138. for dep in section.needs:
  139. if dep not in visited:
  140. if has_cycle(dep):
  141. return True
  142. elif dep in rec_stack:
  143. raise ValueError(
  144. f"Circular dependency detected: '{section_key}' depends on '{dep}', "
  145. f"which creates a cycle"
  146. )
  147. rec_stack.remove(section_key)
  148. return False
  149. for section_key in self._sections:
  150. if section_key not in visited:
  151. has_cycle(section_key)
  152. def is_section_satisfied(self, section_key: str) -> bool:
  153. """Check if all dependencies for a section are satisfied.
  154. A dependency is satisfied if:
  155. 1. The dependency section exists
  156. 2. The dependency section is enabled (if it has a toggle)
  157. Args:
  158. section_key: The key of the section to check
  159. Returns:
  160. True if all dependencies are satisfied, False otherwise
  161. """
  162. section = self._sections.get(section_key)
  163. if not section:
  164. return False
  165. # No dependencies = always satisfied
  166. if not section.needs:
  167. return True
  168. # Check each dependency
  169. for dep_key in section.needs:
  170. dep_section = self._sections.get(dep_key)
  171. if not dep_section:
  172. logger.warning(f"Section '{section_key}' depends on missing section '{dep_key}'")
  173. return False
  174. # Check if dependency is enabled
  175. if not dep_section.is_enabled():
  176. logger.debug(f"Section '{section_key}' dependency '{dep_key}' is disabled")
  177. return False
  178. return True
  179. def sort_sections(self) -> None:
  180. """Sort sections with the following priority:
  181. 1. Dependencies come before dependents (topological sort)
  182. 2. Required sections first (in their original order)
  183. 3. Enabled sections with satisfied dependencies next (in their original order)
  184. 4. Disabled sections or sections with unsatisfied dependencies last (in their original order)
  185. This maintains the original ordering within each group while organizing
  186. sections logically for display and user interaction, and ensures that
  187. sections are prompted in the correct dependency order.
  188. """
  189. # First, perform topological sort to respect dependencies
  190. sorted_keys = self._topological_sort()
  191. # Then apply priority sorting within dependency groups
  192. section_items = [(key, self._sections[key]) for key in sorted_keys]
  193. # Define sort key: (priority, original_index)
  194. # Priority: 0 = required, 1 = enabled with satisfied dependencies, 2 = disabled or unsatisfied dependencies
  195. def get_sort_key(item_with_index):
  196. index, (key, section) = item_with_index
  197. if section.required:
  198. priority = 0
  199. elif section.is_enabled() and self.is_section_satisfied(key):
  200. priority = 1
  201. else:
  202. priority = 2
  203. return (priority, index)
  204. # Sort with original index to maintain order within each priority group
  205. # Note: This preserves the topological order from earlier
  206. sorted_items = sorted(
  207. enumerate(section_items),
  208. key=get_sort_key
  209. )
  210. # Rebuild _sections dict in new order
  211. self._sections = {key: section for _, (key, section) in sorted_items}
  212. def _topological_sort(self) -> List[str]:
  213. """Perform topological sort on sections based on dependencies using Kahn's algorithm."""
  214. in_degree = {key: len(section.needs) for key, section in self._sections.items()}
  215. queue = [key for key, degree in in_degree.items() if degree == 0]
  216. queue.sort(key=lambda k: list(self._sections.keys()).index(k)) # Preserve original order
  217. result = []
  218. while queue:
  219. current = queue.pop(0)
  220. result.append(current)
  221. # Update in-degree for dependent sections
  222. for key, section in self._sections.items():
  223. if current in section.needs:
  224. in_degree[key] -= 1
  225. if in_degree[key] == 0:
  226. queue.append(key)
  227. # Fallback to original order if cycle detected
  228. if len(result) != len(self._sections):
  229. logger.warning("Topological sort incomplete - using original order")
  230. return list(self._sections.keys())
  231. return result
  232. def get_sections(self) -> Dict[str, VariableSection]:
  233. """Get all sections in the collection."""
  234. return self._sections.copy()
  235. def get_section(self, key: str) -> Optional[VariableSection]:
  236. """Get a specific section by its key."""
  237. return self._sections.get(key)
  238. def has_sections(self) -> bool:
  239. """Check if the collection has any sections."""
  240. return bool(self._sections)
  241. def get_all_values(self) -> dict[str, Any]:
  242. """Get all variable values as a dictionary."""
  243. # NOTE: Uses _variable_map for O(1) access
  244. return {name: var.convert(var.value) for name, var in self._variable_map.items()}
  245. def get_satisfied_values(self) -> dict[str, Any]:
  246. """Get variable values only from sections with satisfied dependencies.
  247. This respects both toggle states and section dependencies, ensuring that:
  248. - Variables from disabled sections (toggle=false) are excluded
  249. - Variables from sections with unsatisfied dependencies are excluded
  250. Returns:
  251. Dictionary of variable names to values for satisfied sections only
  252. """
  253. satisfied_values = {}
  254. for section_key, section in self._sections.items():
  255. # Skip sections with unsatisfied dependencies
  256. if not self.is_section_satisfied(section_key):
  257. logger.debug(f"Excluding variables from section '{section_key}' - dependencies not satisfied")
  258. continue
  259. # Skip disabled sections (toggle check)
  260. if not section.is_enabled():
  261. logger.debug(f"Excluding variables from section '{section_key}' - section is disabled")
  262. continue
  263. # Include all variables from this satisfied section
  264. for var_name, variable in section.variables.items():
  265. satisfied_values[var_name] = variable.convert(variable.value)
  266. return satisfied_values
  267. def get_sensitive_variables(self) -> Dict[str, Any]:
  268. """Get only the sensitive variables with their values."""
  269. return {name: var.value for name, var in self._variable_map.items() if var.sensitive and var.value}
  270. def apply_defaults(self, defaults: dict[str, Any], origin: str = "cli") -> list[str]:
  271. """Apply default values to variables, updating their origin.
  272. Args:
  273. defaults: Dictionary mapping variable names to their default values
  274. origin: Source of these defaults (e.g., 'config', 'cli')
  275. Returns:
  276. List of variable names that were successfully updated
  277. """
  278. # NOTE: This method uses the _variable_map for a significant performance gain,
  279. # as it allows direct O(1) lookup of variables instead of iterating
  280. # through all sections to find a match.
  281. successful = []
  282. errors = []
  283. for var_name, value in defaults.items():
  284. try:
  285. variable = self._variable_map.get(var_name)
  286. if not variable:
  287. logger.warning(f"Variable '{var_name}' not found in template")
  288. continue
  289. # Store original value before overriding (for display purposes)
  290. # Only store if this is the first time config is being applied
  291. if origin == "config" and not hasattr(variable, '_original_stored'):
  292. variable.original_value = variable.value
  293. variable._original_stored = True
  294. # Convert and set the new value
  295. converted_value = variable.convert(value)
  296. variable.value = converted_value
  297. # Set origin to the current source (not a chain)
  298. variable.origin = origin
  299. successful.append(var_name)
  300. except ValueError as e:
  301. error_msg = f"Invalid value for '{var_name}': {value} - {e}"
  302. errors.append(error_msg)
  303. logger.error(error_msg)
  304. if errors:
  305. logger.warning(f"Some defaults failed to apply: {'; '.join(errors)}")
  306. return successful
  307. def validate_all(self) -> None:
  308. """Validate all variables in the collection, skipping disabled and unsatisfied sections."""
  309. errors: list[str] = []
  310. for section_key, section in self._sections.items():
  311. # Skip sections with unsatisfied dependencies or disabled via toggle
  312. if not self.is_section_satisfied(section_key) or not section.is_enabled():
  313. logger.debug(f"Skipping validation for section '{section_key}'")
  314. continue
  315. # Validate each variable in the section
  316. for var_name, variable in section.variables.items():
  317. try:
  318. # Skip autogenerated variables when empty
  319. if variable.autogenerated and not variable.value:
  320. continue
  321. # Check required fields
  322. if variable.value is None:
  323. if variable.is_required():
  324. errors.append(f"{section.key}.{var_name} (required - no default provided)")
  325. continue
  326. # Validate typed value
  327. typed = variable.convert(variable.value)
  328. if variable.type not in ("bool",) and not typed:
  329. msg = f"{section.key}.{var_name}"
  330. errors.append(f"{msg} (required - cannot be empty)" if variable.is_required() else f"{msg} (empty)")
  331. except ValueError as e:
  332. errors.append(f"{section.key}.{var_name} (invalid format: {e})")
  333. if errors:
  334. error_msg = "Variable validation failed: " + ", ".join(errors)
  335. logger.error(error_msg)
  336. raise ValueError(error_msg)
  337. def merge(self, other_spec: Union[Dict[str, Any], 'VariableCollection'], origin: str = "override") -> 'VariableCollection':
  338. """Merge another spec or VariableCollection into this one with precedence tracking.
  339. OPTIMIZED: Works directly on objects without dict conversions for better performance.
  340. The other spec/collection has higher precedence and will override values in self.
  341. Creates a new VariableCollection with merged data.
  342. Args:
  343. other_spec: Either a spec dictionary or another VariableCollection to merge
  344. origin: Origin label for variables from other_spec (e.g., 'template', 'config')
  345. Returns:
  346. New VariableCollection with merged data
  347. Example:
  348. module_vars = VariableCollection(module_spec)
  349. template_vars = module_vars.merge(template_spec, origin='template')
  350. # Variables from template_spec override module_spec
  351. # Origins tracked: 'module' or 'module -> template'
  352. """
  353. # Convert dict to VariableCollection if needed (only once)
  354. if isinstance(other_spec, dict):
  355. other = VariableCollection(other_spec)
  356. else:
  357. other = other_spec
  358. # Create new collection without calling __init__ (optimization)
  359. merged = VariableCollection.__new__(VariableCollection)
  360. merged._sections = {}
  361. merged._variable_map = {}
  362. # First pass: clone sections from self
  363. for section_key, self_section in self._sections.items():
  364. if section_key in other._sections:
  365. # Section exists in both - will merge
  366. merged._sections[section_key] = self._merge_sections(
  367. self_section,
  368. other._sections[section_key],
  369. origin
  370. )
  371. else:
  372. # Section only in self - clone it
  373. merged._sections[section_key] = self_section.clone()
  374. # Second pass: add sections that only exist in other
  375. for section_key, other_section in other._sections.items():
  376. if section_key not in merged._sections:
  377. # New section from other - clone with origin update
  378. merged._sections[section_key] = other_section.clone(origin_update=origin)
  379. # Rebuild variable map for O(1) lookups
  380. for section in merged._sections.values():
  381. for var_name, variable in section.variables.items():
  382. merged._variable_map[var_name] = variable
  383. return merged
  384. def _merge_sections(self, self_section: VariableSection, other_section: VariableSection, origin: str) -> VariableSection:
  385. """Merge two sections, with other_section taking precedence."""
  386. merged_section = self_section.clone()
  387. # Update section metadata from other (other takes precedence)
  388. for attr in ('title', 'description', 'toggle'):
  389. if getattr(other_section, attr):
  390. setattr(merged_section, attr, getattr(other_section, attr))
  391. merged_section.required = other_section.required
  392. if other_section.needs:
  393. merged_section.needs = other_section.needs.copy()
  394. # Merge variables
  395. for var_name, other_var in other_section.variables.items():
  396. if var_name in merged_section.variables:
  397. # Variable exists in both - merge with other taking precedence
  398. self_var = merged_section.variables[var_name]
  399. # Build update dict with ONLY explicitly provided fields from other
  400. update = {'origin': origin}
  401. field_map = {
  402. 'type': other_var.type,
  403. 'description': other_var.description,
  404. 'prompt': other_var.prompt,
  405. 'options': other_var.options,
  406. 'sensitive': other_var.sensitive,
  407. 'extra': other_var.extra
  408. }
  409. # Add fields that were explicitly provided and have values
  410. for field, value in field_map.items():
  411. if field in other_var._explicit_fields and value:
  412. update[field] = value
  413. # Special handling for value/default
  414. if ('value' in other_var._explicit_fields or 'default' in other_var._explicit_fields) and other_var.value is not None:
  415. update['value'] = other_var.value
  416. merged_section.variables[var_name] = self_var.clone(update=update)
  417. else:
  418. # New variable from other - clone with origin
  419. merged_section.variables[var_name] = other_var.clone(update={'origin': origin})
  420. return merged_section
  421. def filter_to_used(self, used_variables: Set[str], keep_sensitive: bool = True) -> 'VariableCollection':
  422. """Filter collection to only variables that are used (or sensitive).
  423. OPTIMIZED: Works directly on objects without dict conversions for better performance.
  424. Creates a new VariableCollection containing only the variables in used_variables.
  425. Sections with no remaining variables are removed.
  426. Args:
  427. used_variables: Set of variable names that are actually used
  428. keep_sensitive: If True, also keep sensitive variables even if not in used set
  429. Returns:
  430. New VariableCollection with filtered variables
  431. Example:
  432. all_vars = VariableCollection(spec)
  433. used_vars = all_vars.filter_to_used({'var1', 'var2', 'var3'})
  434. # Only var1, var2, var3 (and any sensitive vars) remain
  435. """
  436. # Create new collection without calling __init__ (optimization)
  437. filtered = VariableCollection.__new__(VariableCollection)
  438. filtered._sections = {}
  439. filtered._variable_map = {}
  440. # Filter each section
  441. for section_key, section in self._sections.items():
  442. # Create a new section with same metadata
  443. filtered_section = VariableSection({
  444. 'key': section.key,
  445. 'title': section.title,
  446. 'description': section.description,
  447. 'toggle': section.toggle,
  448. 'required': section.required,
  449. 'needs': section.needs.copy() if section.needs else None,
  450. })
  451. # Clone only the variables that should be included
  452. for var_name, variable in section.variables.items():
  453. # Include if used OR if sensitive (and keep_sensitive is True)
  454. should_include = (
  455. var_name in used_variables or
  456. (keep_sensitive and variable.sensitive)
  457. )
  458. if should_include:
  459. filtered_section.variables[var_name] = variable.clone()
  460. # Only add section if it has variables
  461. if filtered_section.variables:
  462. filtered._sections[section_key] = filtered_section
  463. # Add variables to map
  464. for var_name, variable in filtered_section.variables.items():
  465. filtered._variable_map[var_name] = variable
  466. return filtered
  467. def get_all_variable_names(self) -> Set[str]:
  468. """Get set of all variable names across all sections.
  469. Returns:
  470. Set of all variable names
  471. """
  472. return set(self._variable_map.keys())