Procházet zdrojové kódy

refactor: Simplify variable handling and validation

- Simplified the  and  classes in .
- Introduced helper methods to reduce code duplication and improve clarity.
- Refactored the  constructor to be more modular.
- Added a  for efficient O(1) variable lookups.

Closes #1252
xcad před 9 měsíci
rodič
revize
acfeff3156
1 změnil soubory, kde provedl 87 přidání a 86 odebrání
  1. 87 86
      cli/core/variables.py

+ 87 - 86
cli/core/variables.py

@@ -61,11 +61,32 @@ class Variable:
       except ValueError as exc:
         raise ValueError(f"Invalid default for variable '{self.name}': {exc}")
 
-  def validate(self, value: Any) -> None:
-    """Validate a value based on the variable's type and constraints."""
-    if self.type not in ["bool"] and (value is None or value == ""):
+  # -------------------------
+  # SECTION: Validation Helpers
+  # -------------------------
+
+  def _validate_not_empty(self, value: Any, converted_value: Any) -> None:
+    """Validate that a value is not empty for non-boolean types."""
+    if self.type not in ["bool"] and (converted_value is None or converted_value == ""):
       raise ValueError("value cannot be empty")
 
+  def _validate_enum_option(self, value: str) -> None:
+    """Validate that a value is in the allowed enum options."""
+    if self.options and value not in self.options:
+      raise ValueError(f"value must be one of: {', '.join(self.options)}")
+
+  def _validate_regex_pattern(self, value: str, pattern: re.Pattern, error_msg: str) -> None:
+    """Validate that a value matches a regex pattern."""
+    if not pattern.fullmatch(value):
+      raise ValueError(error_msg)
+
+  def _validate_url_structure(self, parsed_url) -> None:
+    """Validate that a parsed URL has required components."""
+    if not (parsed_url.scheme and parsed_url.netloc):
+      raise ValueError("value must be a valid URL (include scheme and host)")
+
+  # !SECTION
+
   # -------------------------
   # SECTION: Type Conversion
   # -------------------------
@@ -132,8 +153,7 @@ class Variable:
     if value == "":
       return None
     val = str(value)
-    if self.options and val not in self.options:
-      raise ValueError(f"value must be one of: {', '.join(self.options)}")
+    self._validate_enum_option(val)
     return val
 
   def _convert_hostname(self, value: Any) -> str:
@@ -141,10 +161,8 @@ class Variable:
     val = str(value).strip()
     if not val:
       return ""
-    if val.lower() == "localhost":
-      return val
-    if not HOSTNAME_REGEX.fullmatch(val):
-      raise ValueError("value must be a valid hostname")
+    if val.lower() != "localhost":
+      self._validate_regex_pattern(val, HOSTNAME_REGEX, "value must be a valid hostname")
     return val
 
   def _convert_url(self, value: Any) -> str:
@@ -153,8 +171,7 @@ class Variable:
     if not val:
       return ""
     parsed = urlparse(val)
-    if not (parsed.scheme and parsed.netloc):
-      raise ValueError("value must be a valid URL (include scheme and host)")
+    self._validate_url_structure(parsed)
     return val
 
   def _convert_email(self, value: Any) -> str:
@@ -162,8 +179,7 @@ class Variable:
     val = str(value).strip()
     if not val:
       return ""
-    if not EMAIL_REGEX.fullmatch(val):
-      raise ValueError("value must be a valid email address")
+    self._validate_regex_pattern(val, EMAIL_REGEX, "value must be a valid email address")
     return val
 
   def get_typed_value(self) -> Any:
@@ -243,34 +259,43 @@ class VariableCollection:
     if not isinstance(spec, dict):
       raise ValueError("Spec must be a dictionary")
     
-    self._set: Dict[str, VariableSection] = {}
-    
-    # Initialize sections and their variables
+    self._sections: Dict[str, VariableSection] = {}
+    # NOTE: The _variable_map provides a flat, O(1) lookup for any variable by its name,
+    # avoiding the need to iterate through sections. It stores references to the same
+    # Variable objects contained in the _set structure.
+    self._variable_map: Dict[str, Variable] = {}
+    self._initialize_sections(spec)
+
+  def _initialize_sections(self, spec: dict[str, Any]) -> None:
+    """Initialize sections from the spec."""
     for section_key, section_data in spec.items():
       if not isinstance(section_data, dict):
         continue
-        
-      # Create section data with the key included
-      section_init_data = {
-        "key": section_key,
-        "title": section_data.get("title", section_key.replace("_", " ").title()),
-        "prompt": section_data.get("prompt"),
-        "description": section_data.get("description"),
-        "toggle": section_data.get("toggle"),
-        "required": section_data.get("required", section_key == "general")
-      }
-      
-      section = VariableSection(section_init_data)
       
-      # Initialize variables in this section
-      if "vars" in section_data:
-        for var_name, var_data in section_data["vars"].items():
-          # Add variable name to the data
-          var_init_data = {"name": var_name, **var_data}
-          variable = Variable(var_init_data)
-          section.variables[var_name] = variable
-      
-      self._set[section_key] = section
+      section = self._create_section(section_key, section_data)
+      self._initialize_variables(section, section_data.get("vars", {}))
+      self._sections[section_key] = section
+
+  def _create_section(self, key: str, data: dict[str, Any]) -> VariableSection:
+    """Create a VariableSection from data."""
+    section_init_data = {
+      "key": key,
+      "title": data.get("title", key.replace("_", " ").title()),
+      "prompt": data.get("prompt"),
+      "description": data.get("description"),
+      "toggle": data.get("toggle"),
+      "required": data.get("required", key == "general")
+    }
+    return VariableSection(section_init_data)
+
+  def _initialize_variables(self, section: VariableSection, vars_data: dict[str, Any]) -> None:
+    """Initialize variables for a section."""
+    for var_name, var_data in vars_data.items():
+      var_init_data = {"name": var_name, **var_data}
+      variable = Variable(var_init_data)
+      section.variables[var_name] = variable
+      # NOTE: Populate the direct lookup map for efficient access.
+      self._variable_map[var_name] = variable
 
   # -------------------------
   # SECTION: Helper Methods
@@ -280,61 +305,40 @@ class VariableCollection:
   # by centralizing common variable collection operations
 
   def get_all_values(self) -> dict[str, Any]:
-    """Get all variable values as a dictionary.
-    Returns:
-      Dictionary mapping variable names to their typed values
-    """
-
-    # NOTE: Eliminates the need to iterate through sections and variables manually
-    # in module.py _extract_current_variable_values() method
-
+    """Get all variable values as a dictionary."""
+    # NOTE: This method is optimized to use the _variable_map for direct O(1) access
+    # to each variable, which is much faster than iterating through sections.
     all_values = {}
-    for section in self._set.values():
-      for var_name, variable in section.variables.items():
-        all_values[var_name] = variable.get_typed_value()
+    for var_name, variable in self._variable_map.items():
+      all_values[var_name] = variable.get_typed_value()
     return all_values
 
   def apply_overrides(self, overrides: dict[str, Any], origin_suffix: str = " -> cli") -> list[str]:
-    """Apply multiple variable overrides at once.
-    
-    Args:
-      overrides: Dictionary of variable names to values
-      origin_suffix: Suffix to append to origins for overridden variables
-      
-    Returns:
-      List of variable names that were successfully overridden
-    """
-
-    # NOTE: Replaces the complex _apply_cli_overrides() method in module.py
-    # by centralizing override logic with proper error handling and origin tracking
-
+    """Apply multiple variable overrides at once."""
+    # NOTE: This method uses the _variable_map for a significant performance gain,
+    # as it allows direct O(1) lookup of variables instead of iterating
+    # through all sections to find a match.
     successful_overrides = []
     errors = []
     
     for var_name, value in overrides.items():
       try:
-        # Find and update the variable
-        found = False
-        for section in self._set.values():
-          if var_name in section.variables:
-            variable = section.variables[var_name]
-            
-            # Convert and set the new value
-            converted_value = variable.convert(value)
-            variable.value = converted_value
-            
-            # Update origin to show override
-            if variable.origin:
-              variable.origin = variable.origin + origin_suffix
-            else:
-              variable.origin = origin_suffix.lstrip(" -> ")
-            
-            successful_overrides.append(var_name)
-            found = True
-            break
-        
-        if not found:
+        variable = self._variable_map.get(var_name)
+        if not variable:
           logger.warning(f"Variable '{var_name}' not found in template")
+          continue
+        
+        # Convert and set the new value
+        converted_value = variable.convert(value)
+        variable.value = converted_value
+        
+        # Update origin to show override
+        if variable.origin:
+          variable.origin = variable.origin + origin_suffix
+        else:
+          variable.origin = origin_suffix.lstrip(" -> ")
+        
+        successful_overrides.append(var_name)
           
       except ValueError as e:
         error_msg = f"Invalid override value for '{var_name}': {value} - {e}"
@@ -342,12 +346,11 @@ class VariableCollection:
         logger.error(error_msg)
     
     if errors:
-      # Log errors but don't stop the process
       logger.warning(f"Some CLI overrides failed: {'; '.join(errors)}")
     
   def validate_all(self) -> None:
     """Validate all variables in the collection, skipping disabled sections."""
-    for section in self._set.values():
+    for section in self._sections.values():
       # Check if the section is disabled by a toggle
       if section.toggle:
         toggle_var = section.variables.get(section.toggle)
@@ -364,5 +367,3 @@ class VariableCollection:
   # !SECTION
 
 # !SECTION
-
-# !SECTION