Skip to content

Commit 404bb85

Browse files
committed
clean up rust duplication
1 parent 46dc054 commit 404bb85

File tree

2 files changed

+107
-31
lines changed

2 files changed

+107
-31
lines changed

rust/src/regex/mod.rs

Lines changed: 102 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use pyo3::exceptions::PyValueError;
44
use pyo3::prelude::*;
55
use pyo3_polars::derive::polars_expr;
66
use regex::Regex;
7+
use std::collections::hash_map::Entry;
78
use std::collections::HashMap;
89

910
#[pyfunction]
@@ -14,6 +15,25 @@ pub fn py_validate_regex(regex: &str) -> PyResult<()> {
1415
}
1516
}
1617

18+
/// Get or compile a regex from the cache.
19+
/// Uses the Entry API for efficient single-lookup caching.
20+
fn get_or_compile_regex<'a>(
21+
regex_cache: &'a mut HashMap<String, Regex>,
22+
pattern: &str,
23+
) -> PolarsResult<&'a Regex> {
24+
match regex_cache.entry(pattern.to_string()) {
25+
Entry::Occupied(entry) => Ok(entry.into_mut()),
26+
Entry::Vacant(entry) => {
27+
let regex = Regex::new(pattern).map_err(|e| {
28+
PolarsError::ComputeError(
29+
format!("Invalid regex pattern '{}': {}", pattern, e).into(),
30+
)
31+
})?;
32+
Ok(entry.insert(regex))
33+
}
34+
}
35+
}
36+
1737
#[polars_expr(output_type=Int32)]
1838
fn regexp_instr(inputs: &[Series]) -> PolarsResult<Series> {
1939
let text_series = inputs[0].str()?;
@@ -42,20 +62,7 @@ fn regexp_instr(inputs: &[Series]) -> PolarsResult<Series> {
4262
Some(0)
4363
} else {
4464
// Get or compile regex, return error if invalid
45-
if !regex_cache.contains_key(pattern) {
46-
match Regex::new(pattern) {
47-
Ok(re) => {
48-
regex_cache.insert(pattern.to_string(), re);
49-
}
50-
Err(e) => {
51-
return Err(PolarsError::ComputeError(
52-
format!("Invalid regex pattern '{}': {}", pattern, e).into(),
53-
));
54-
}
55-
}
56-
}
57-
58-
let regex = regex_cache.get(pattern).unwrap();
65+
let regex = get_or_compile_regex(&mut regex_cache, pattern)?;
5966

6067
// Try to find a match
6168
if let Some(captures) = regex.captures(text) {
@@ -111,20 +118,7 @@ fn regexp_extract_all(inputs: &[Series]) -> PolarsResult<Series> {
111118
Some(Series::new_empty(PlSmallStr::EMPTY, &DataType::String))
112119
} else {
113120
// Get or compile regex, return error if invalid
114-
if !regex_cache.contains_key(pattern) {
115-
match Regex::new(pattern) {
116-
Ok(re) => {
117-
regex_cache.insert(pattern.to_string(), re);
118-
}
119-
Err(e) => {
120-
return Err(PolarsError::ComputeError(
121-
format!("Invalid regex pattern '{}': {}", pattern, e).into(),
122-
));
123-
}
124-
}
125-
}
126-
127-
let regex = regex_cache.get(pattern).unwrap();
121+
let regex = get_or_compile_regex(&mut regex_cache, pattern)?;
128122
let idx_usize = idx as usize;
129123
let mut matches = Vec::new();
130124

@@ -137,7 +131,10 @@ fn regexp_extract_all(inputs: &[Series]) -> PolarsResult<Series> {
137131
}
138132

139133
// Return as Series
140-
Some(Series::from_iter(matches))
134+
Some(
135+
StringChunked::from_iter_values(PlSmallStr::EMPTY, matches.into_iter())
136+
.into_series(),
137+
)
141138
}
142139
}
143140
_ => None, // If any input is null, return null
@@ -157,3 +154,79 @@ fn extract_all_output_type(input_fields: &[Field]) -> PolarsResult<Field> {
157154
DataType::List(Box::new(DataType::String)),
158155
))
159156
}
157+
158+
#[cfg(test)]
mod tests {
    use super::*;

    // Pure Rust tests - these can run with `cargo test`

    /// A valid pattern compiles and the returned regex matches as expected.
    #[test]
    fn test_get_or_compile_regex_valid_pattern() {
        let mut cache = HashMap::new();
        let pattern = r"\d+";

        let result = get_or_compile_regex(&mut cache, pattern);
        assert!(result.is_ok());

        let regex = result.unwrap();
        assert!(regex.is_match("123"));
        assert!(!regex.is_match("abc"));
    }

    /// Repeated lookups of one pattern reuse a single cache entry, and the
    /// regex returned on the second call is still usable.
    #[test]
    fn test_get_or_compile_regex_caches() {
        let mut cache = HashMap::new();
        let pattern = r"[a-z]+";

        // First call should compile and cache
        let result1 = get_or_compile_regex(&mut cache, pattern);
        assert!(result1.is_ok());
        assert_eq!(cache.len(), 1);

        // Second call should use cache: the size must not change and the
        // returned regex must still behave like the requested pattern.
        let result2 = get_or_compile_regex(&mut cache, pattern);
        assert!(result2.is_ok());
        let cached = result2.unwrap();
        assert!(cached.is_match("abc"));
        assert!(!cached.is_match("123"));
        assert_eq!(cache.len(), 1);

        // Different pattern should add to cache
        let result3 = get_or_compile_regex(&mut cache, r"\d+");
        assert!(result3.is_ok());
        assert_eq!(cache.len(), 2);
    }

    /// An invalid pattern yields the formatted error and leaves the cache empty.
    #[test]
    fn test_get_or_compile_regex_invalid_pattern() {
        let mut cache = HashMap::new();
        let pattern = r"[invalid(";

        let result = get_or_compile_regex(&mut cache, pattern);
        assert!(result.is_err());

        // Verify error message format
        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("Invalid regex pattern"));
        assert!(err_msg.contains("[invalid("));

        // A rejected pattern must not be inserted into the cache.
        assert!(cache.is_empty());
    }

    /// Capture-group and word-boundary patterns compile and match.
    #[test]
    fn test_get_or_compile_regex_special_patterns() {
        let mut cache = HashMap::new();

        // Test email pattern
        let email_pattern = r"(\w+)@(\w+)\.(\w+)";
        let result = get_or_compile_regex(&mut cache, email_pattern);
        assert!(result.is_ok());
        assert!(result.unwrap().is_match("test@example.com"));

        // Test word boundary pattern
        let word_boundary = r"\bword\b";
        let result = get_or_compile_regex(&mut cache, word_boundary);
        assert!(result.is_ok());
        assert!(result.unwrap().is_match("a word here"));
    }

    // PyO3 tests - these are tested via Python integration tests
    // Note: py_validate_regex is tested in tests/_backends/local/functions/test_regexp_functions.py
    // because standalone PyO3 tests require Python runtime linking
}

tests/_backends/local/functions/test_regexp_functions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,13 +468,16 @@ def test_regexp_invalid_pattern_with_column_pattern(local_session):
468468
)
469469
# raises an execution time error because we cannot validate the pattern when the logical plan is being
470470
# constructed.
471+
# Functions that use Polars built-in methods produce "regex error: regex parse error:" messages
472+
# Functions that use our Rust plugins produce "Invalid regex pattern" messages
471473
with pytest.raises(ExecutionError, match="Failed to execute query: regex error: regex parse error:"):
472474
df.select(text.regexp_count("text_col", col("pattern"))).to_polars()
473475
with pytest.raises(ExecutionError, match="Failed to execute query: regex error: regex parse error:"):
474476
df.select(text.regexp_extract("text_col", col("pattern"), 0)).to_polars()
475-
with pytest.raises(ExecutionError, match="Failed to execute query: .* Invalid regex pattern .*"):
477+
with pytest.raises(ExecutionError, match="Failed to execute query: .* Invalid regex pattern .* regex parse error"):
476478
df.select(text.regexp_extract_all("text_col", col("pattern"), 0)).to_polars()
477-
with pytest.raises(ExecutionError, match="Failed to execute query: .* Invalid regex pattern .*"):
479+
# regexp_instr with literal idx uses Polars str.extract (fast path), so error comes from Polars
480+
with pytest.raises(ExecutionError, match="Failed to execute query: regex error: regex parse error:"):
478481
df.select(text.regexp_instr("text_col", col("pattern"), 0)).to_polars()
479482
with pytest.raises(ExecutionError, match="Failed to execute query: regex error: regex parse error:"):
480483
df.select(text.regexp_substr("text_col", col("pattern"))).to_polars()

0 commit comments

Comments
 (0)