@@ -4,6 +4,7 @@ use pyo3::exceptions::PyValueError;
44use pyo3:: prelude:: * ;
55use pyo3_polars:: derive:: polars_expr;
66use regex:: Regex ;
7+ use std:: collections:: hash_map:: Entry ;
78use std:: collections:: HashMap ;
89
910#[ pyfunction]
@@ -14,6 +15,25 @@ pub fn py_validate_regex(regex: &str) -> PyResult<()> {
1415 }
1516}
1617
18+ /// Get or compile a regex from the cache.
19+ /// Uses the Entry API for efficient single-lookup caching.
20+ fn get_or_compile_regex < ' a > (
21+ regex_cache : & ' a mut HashMap < String , Regex > ,
22+ pattern : & str ,
23+ ) -> PolarsResult < & ' a Regex > {
24+ match regex_cache. entry ( pattern. to_string ( ) ) {
25+ Entry :: Occupied ( entry) => Ok ( entry. into_mut ( ) ) ,
26+ Entry :: Vacant ( entry) => {
27+ let regex = Regex :: new ( pattern) . map_err ( |e| {
28+ PolarsError :: ComputeError (
29+ format ! ( "Invalid regex pattern '{}': {}" , pattern, e) . into ( ) ,
30+ )
31+ } ) ?;
32+ Ok ( entry. insert ( regex) )
33+ }
34+ }
35+ }
36+
1737#[ polars_expr( output_type=Int32 ) ]
1838fn regexp_instr ( inputs : & [ Series ] ) -> PolarsResult < Series > {
1939 let text_series = inputs[ 0 ] . str ( ) ?;
@@ -42,20 +62,7 @@ fn regexp_instr(inputs: &[Series]) -> PolarsResult<Series> {
4262 Some ( 0 )
4363 } else {
4464 // Get or compile regex, return error if invalid
45- if !regex_cache. contains_key ( pattern) {
46- match Regex :: new ( pattern) {
47- Ok ( re) => {
48- regex_cache. insert ( pattern. to_string ( ) , re) ;
49- }
50- Err ( e) => {
51- return Err ( PolarsError :: ComputeError (
52- format ! ( "Invalid regex pattern '{}': {}" , pattern, e) . into ( ) ,
53- ) ) ;
54- }
55- }
56- }
57-
58- let regex = regex_cache. get ( pattern) . unwrap ( ) ;
65+ let regex = get_or_compile_regex ( & mut regex_cache, pattern) ?;
5966
6067 // Try to find a match
6168 if let Some ( captures) = regex. captures ( text) {
@@ -111,20 +118,7 @@ fn regexp_extract_all(inputs: &[Series]) -> PolarsResult<Series> {
111118 Some ( Series :: new_empty ( PlSmallStr :: EMPTY , & DataType :: String ) )
112119 } else {
113120 // Get or compile regex, return error if invalid
114- if !regex_cache. contains_key ( pattern) {
115- match Regex :: new ( pattern) {
116- Ok ( re) => {
117- regex_cache. insert ( pattern. to_string ( ) , re) ;
118- }
119- Err ( e) => {
120- return Err ( PolarsError :: ComputeError (
121- format ! ( "Invalid regex pattern '{}': {}" , pattern, e) . into ( ) ,
122- ) ) ;
123- }
124- }
125- }
126-
127- let regex = regex_cache. get ( pattern) . unwrap ( ) ;
121+ let regex = get_or_compile_regex ( & mut regex_cache, pattern) ?;
128122 let idx_usize = idx as usize ;
129123 let mut matches = Vec :: new ( ) ;
130124
@@ -137,7 +131,10 @@ fn regexp_extract_all(inputs: &[Series]) -> PolarsResult<Series> {
137131 }
138132
139133 // Return as Series
140- Some ( Series :: from_iter ( matches) )
134+ Some (
135+ StringChunked :: from_iter_values ( PlSmallStr :: EMPTY , matches. into_iter ( ) )
136+ . into_series ( ) ,
137+ )
141138 }
142139 }
143140 _ => None , // If any input is null, return null
@@ -157,3 +154,79 @@ fn extract_all_output_type(input_fields: &[Field]) -> PolarsResult<Field> {
157154 DataType :: List ( Box :: new ( DataType :: String ) ) ,
158155 ) )
159156}
157+
#[cfg(test)]
mod tests {
    use super::*;

    // These tests exercise only pure Rust code and run under plain `cargo test`.

    /// Builds the empty pattern -> compiled-regex cache that each test starts from.
    fn empty_cache() -> HashMap<String, Regex> {
        HashMap::new()
    }

    /// A well-formed pattern compiles and the returned regex matches as expected.
    #[test]
    fn test_get_or_compile_regex_valid_pattern() {
        let mut cache = empty_cache();

        let regex =
            get_or_compile_regex(&mut cache, r"\d+").expect("a valid pattern should compile");

        assert!(regex.is_match("123"));
        assert!(!regex.is_match("abc"));
    }

    /// The cache grows only when a pattern is seen for the first time.
    #[test]
    fn test_get_or_compile_regex_caches() {
        let mut cache = empty_cache();
        let letters = r"[a-z]+";

        // First request compiles the pattern and stores it.
        assert!(get_or_compile_regex(&mut cache, letters).is_ok());
        assert_eq!(cache.len(), 1);

        // Asking again for the same pattern must not add a second entry.
        assert!(get_or_compile_regex(&mut cache, letters).is_ok());
        assert_eq!(cache.len(), 1);

        // A new pattern gets its own entry.
        assert!(get_or_compile_regex(&mut cache, r"\d+").is_ok());
        assert_eq!(cache.len(), 2);
    }

    /// A malformed pattern yields an error whose message names the offending pattern.
    #[test]
    fn test_get_or_compile_regex_invalid_pattern() {
        let mut cache = empty_cache();

        let message = get_or_compile_regex(&mut cache, r"[invalid(")
            .expect_err("a malformed pattern should be rejected")
            .to_string();

        assert!(message.contains("Invalid regex pattern"));
        assert!(message.contains("[invalid("));
    }

    /// Patterns with capture groups and word boundaries compile and match.
    #[test]
    fn test_get_or_compile_regex_special_patterns() {
        let mut cache = empty_cache();

        // Email-like pattern with three capture groups.
        let email = get_or_compile_regex(&mut cache, r"(\w+)@(\w+)\.(\w+)")
            .expect("the email pattern should compile");
        assert!(email.is_match("test@example.com"));

        // Word-boundary anchors.
        let bounded = get_or_compile_regex(&mut cache, r"\bword\b")
            .expect("the word-boundary pattern should compile");
        assert!(bounded.is_match("a word here"));
    }

    // `py_validate_regex` is not covered here: standalone PyO3 tests require linking
    // against a Python runtime, so it is exercised by the Python integration tests in
    // tests/_backends/local/functions/test_regexp_functions.py instead.
}
0 commit comments