prompt.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. from __future__ import annotations
  2. import logging
  3. from typing import Any, Callable
  4. from rich.console import Console
  5. from rich.prompt import Confirm, IntPrompt, Prompt
  6. from .collection import VariableCollection
  7. from .display import DisplayManager
  8. from .variable import Variable
  9. logger = logging.getLogger(__name__)
class PromptHandler:
    """Simple interactive prompt handler for collecting template variables.

    Walks the sections of a ``VariableCollection`` and asks the user for values
    using ``rich`` prompts. ``collect_variables`` returns only the values that
    differ from a variable's current value, plus autogenerated variables that
    were explicitly left as ``None``.
    """
  12. def __init__(self) -> None:
  13. self.console = Console()
  14. self.display = DisplayManager()
  15. def collect_variables(self, variables: VariableCollection) -> dict[str, Any]:
  16. """Collect values for variables by iterating through sections.
  17. Args:
  18. variables: VariableCollection with organized sections and variables
  19. Returns:
  20. Dict of variable names to collected values
  21. """
  22. if not Confirm.ask("Customize any settings?", default=False):
  23. self.console.print("") # Add blank line after prompt
  24. logger.info("User opted to keep all default values")
  25. return {}
  26. self.console.print("") # Add blank line after prompt
  27. collected: dict[str, Any] = {}
  28. # Process each section
  29. for section_key, section in variables.get_sections().items():
  30. if not section.variables:
  31. continue
  32. # Check if dependencies are satisfied
  33. if not self._check_section_dependencies(variables, section_key, section):
  34. continue
  35. # Always show section header first
  36. self.display.display_section_header(section.title, section.description)
  37. # Handle section toggle and determine if enabled
  38. section_will_be_enabled = self._handle_section_toggle(section, collected)
  39. # Collect variables in this section
  40. self._collect_section_variables(section, section_key, section_will_be_enabled, variables, collected)
  41. logger.info(f"Variable collection completed. Collected {len(collected)} values")
  42. return collected
  43. def _check_section_dependencies(self, variables: VariableCollection, section_key: str, section) -> bool:
  44. """Check if section dependencies are satisfied and display skip message if not."""
  45. if not variables.is_section_satisfied(section_key):
  46. # Get list of unsatisfied dependencies for better user feedback
  47. unsatisfied_keys = [dep for dep in section.needs if not variables.is_section_satisfied(dep)]
  48. # Convert section keys to titles for user-friendly display
  49. unsatisfied_titles = []
  50. for dep_key in unsatisfied_keys:
  51. dep_section = variables.get_section(dep_key)
  52. unsatisfied_titles.append(dep_section.title if dep_section else dep_key)
  53. dep_names = ", ".join(unsatisfied_titles) if unsatisfied_titles else "unknown"
  54. self.display.display_skipped(section.title, f"requires {dep_names} to be enabled")
  55. logger.debug(f"Skipping section '{section_key}' - dependencies not satisfied: {dep_names}")
  56. return False
  57. return True
  58. def _handle_section_toggle(self, section, collected: dict[str, Any]) -> bool:
  59. """Handle section toggle prompt and return whether section will be enabled."""
  60. # Required sections are always enabled
  61. if section.required:
  62. logger.debug(f"Processing required section '{section.key}' without toggle prompt")
  63. return True
  64. # Handle optional sections with toggle
  65. if not section.toggle:
  66. return True
  67. toggle_var = section.variables.get(section.toggle)
  68. if not toggle_var:
  69. return True
  70. # Prompt for toggle variable
  71. current_value = toggle_var.convert(toggle_var.value)
  72. new_value = self._prompt_variable(toggle_var, required=section.required)
  73. if new_value != current_value:
  74. collected[toggle_var.name] = new_value
  75. toggle_var.value = new_value
  76. # Return whether section is enabled
  77. return section.is_enabled()
  78. def _collect_section_variables(
  79. self,
  80. section,
  81. section_key: str,
  82. section_enabled: bool,
  83. variables: VariableCollection,
  84. collected: dict[str, Any],
  85. ) -> None:
  86. """Collect values for all variables in a section."""
  87. for var_name, variable in section.variables.items():
  88. # Skip toggle variable (already handled)
  89. if section.toggle and var_name == section.toggle:
  90. continue
  91. # Skip variables with unsatisfied needs
  92. if not variables.is_variable_satisfied(var_name):
  93. logger.debug(f"Skipping variable '{var_name}' - needs not satisfied")
  94. continue
  95. # Skip all variables if section is disabled
  96. if not section_enabled:
  97. logger.debug(f"Skipping variable '{var_name}' from disabled section '{section_key}'")
  98. continue
  99. # Prompt for the variable and update if changed
  100. self._prompt_and_update_variable(variable, collected)
  101. def _prompt_and_update_variable(self, variable: Variable, collected: dict[str, Any]) -> None:
  102. """Prompt for a variable and update collected values if changed."""
  103. current_value = variable.convert(variable.value)
  104. new_value = self._prompt_variable(variable, required=False)
  105. # For autogenerated variables, always update even if None (signals autogeneration)
  106. if variable.autogenerated and new_value is None:
  107. collected[variable.name] = None
  108. variable.value = None
  109. elif new_value != current_value:
  110. collected[variable.name] = new_value
  111. variable.value = new_value
  112. def _prompt_variable(self, variable: Variable, _required: bool = False) -> Any:
  113. """Prompt for a single variable value based on its type.
  114. Args:
  115. variable: The variable to prompt for
  116. _required: Whether the containing section is required (unused, kept for API compatibility)
  117. Returns:
  118. The validated value entered by the user
  119. """
  120. logger.debug(f"Prompting for variable '{variable.name}' (type: {variable.type})")
  121. # Use variable's native methods for prompt text and default value
  122. prompt_text = variable.get_prompt_text()
  123. default_value = variable.get_normalized_default()
  124. # Add lock icon before default value for sensitive or autogenerated variables
  125. if variable.sensitive or variable.autogenerated:
  126. # Format: "Prompt text 🔒 (default)"
  127. # The lock icon goes between the text and the default value in parentheses
  128. prompt_text = f"{prompt_text} {self.display.get_lock_icon()}"
  129. # Check if this specific variable is required (has no default and not autogenerated)
  130. var_is_required = variable.is_required()
  131. # If variable is required, mark it in the prompt
  132. if var_is_required:
  133. prompt_text = f"{prompt_text} [bold red]*required[/bold red]"
  134. handler = self._get_prompt_handler(variable)
  135. # Add validation hint (includes both extra text and enum options)
  136. hint = variable.get_validation_hint()
  137. if hint:
  138. # Show options/extra inline inside parentheses, before the default
  139. prompt_text = f"{prompt_text} [dim]({hint})[/dim]"
  140. while True:
  141. try:
  142. raw = handler(prompt_text, default_value)
  143. # Use Variable's centralized validation method that handles:
  144. # - Type conversion
  145. # - Autogenerated variable detection
  146. # - Required field validation
  147. return variable.validate_and_convert(raw, check_required=True)
  148. # Return the converted value (caller will update variable.value)
  149. except ValueError as exc:
  150. # Conversion/validation failed — show a consistent error message and retry
  151. self._show_validation_error(str(exc))
  152. except Exception as e:
  153. # Unexpected error — log and retry using the stored (unconverted) value
  154. logger.error(f"Error prompting for variable '{variable.name}': {e!s}")
  155. default_value = variable.value
  156. handler = self._get_prompt_handler(variable)
  157. def _get_prompt_handler(self, variable: Variable) -> Callable:
  158. """Return the prompt function for a variable type."""
  159. handlers = {
  160. "bool": self._prompt_bool,
  161. "int": self._prompt_int,
  162. # For enum prompts we pass the variable.extra through so options and extra
  163. # can be combined into a single inline hint.
  164. "enum": lambda text, default: self._prompt_enum(
  165. text,
  166. variable.options or [],
  167. default,
  168. extra=getattr(variable, "extra", None),
  169. ),
  170. }
  171. return handlers.get(
  172. variable.type,
  173. lambda text, default: self._prompt_string(text, default, is_sensitive=variable.sensitive),
  174. )
  175. def _show_validation_error(self, message: str) -> None:
  176. """Display validation feedback consistently."""
  177. self.display.display_validation_error(message)
  178. def _prompt_string(self, prompt_text: str, default: Any = None, is_sensitive: bool = False) -> str | None:
  179. value = Prompt.ask(
  180. prompt_text,
  181. default=str(default) if default is not None else "",
  182. show_default=True,
  183. password=is_sensitive,
  184. )
  185. stripped = value.strip() if value else None
  186. return stripped if stripped else None
  187. def _prompt_bool(self, prompt_text: str, default: Any = None) -> bool | None:
  188. if default is None:
  189. return Confirm.ask(prompt_text, default=None)
  190. converted = default if isinstance(default, bool) else str(default).lower() in ("true", "1", "yes", "on")
  191. return Confirm.ask(prompt_text, default=converted)
  192. def _prompt_int(self, prompt_text: str, default: Any = None) -> int | None:
  193. converted = None
  194. if default is not None:
  195. try:
  196. converted = int(default)
  197. except (ValueError, TypeError):
  198. logger.warning(f"Invalid default integer value: {default}")
  199. return IntPrompt.ask(prompt_text, default=converted)
  200. def _prompt_enum(
  201. self,
  202. prompt_text: str,
  203. options: list[str],
  204. default: Any = None,
  205. _extra: str | None = None,
  206. ) -> str:
  207. """Prompt for enum selection with validation.
  208. Note: prompt_text should already include hint from variable.get_validation_hint()
  209. but we keep this for backward compatibility and fallback.
  210. """
  211. if not options:
  212. return self._prompt_string(prompt_text, default)
  213. # Validate default is in options
  214. if default and str(default) not in options:
  215. default = options[0]
  216. while True:
  217. value = Prompt.ask(
  218. prompt_text,
  219. default=str(default) if default else options[0],
  220. show_default=True,
  221. )
  222. if value in options:
  223. return value
  224. self.console.print(f"[red]Invalid choice. Select from: {', '.join(options)}[/red]")