11import json
2- from typing import BinaryIO , Dict , Optional
2+ from typing import BinaryIO , Dict , Optional , Type
33
4- from mindee .document_config import DocumentConfig , DocumentConfigDict
5- from mindee .documents .bank_check import BankCheck
6- from mindee .documents .custom_document import CustomDocument
4+ from mindee .documents . base import Document , TypeDocument
5+ from mindee .documents .config import DocumentConfig , DocumentConfigDict
6+ from mindee .documents .custom . custom_v1 import CustomV1
77from mindee .documents .financial_document import FinancialDocument
8- from mindee .documents .invoice import Invoice
9- from mindee .documents .passport import Passport
10- from mindee .documents .receipt import Receipt
8+ from mindee .documents .invoice .invoice_v3 import InvoiceV3
9+ from mindee .documents .passport .passport_v1 import PassportV1
10+ from mindee .documents .receipt .receipt_v3 import ReceiptV3
11+ from mindee .documents .us .bank_check .bank_check_v1 import BankCheckV1
1112from mindee .endpoints import OTS_OWNER , CustomEndpoint , HTTPException , StandardEndpoint
1213from mindee .input .page_options import PageOptions
1314from mindee .input .sources import (
2122from mindee .response import PredictResponse
2223
2324
def get_type_var_name(type_var) -> str:
    """
    Get the name of the class a document type refers to.

    :param type_var: Either a ``typing.TypeVar`` declared with ``bound=...``
        or a class object.
    :return: ``__name__`` of the bound class when given a TypeVar,
        or ``__name__`` of the class itself when given a class.
    :raises AttributeError: If ``type_var`` is a TypeVar without a bound,
        or an object that has no ``__bound__`` attribute and is not a class.
    """
    # A concrete class has no ``__bound__``; treat it as its own bound so
    # callers may pass either the class or a TypeVar bound to it.
    if isinstance(type_var, type):
        return type_var.__name__

    bound = type_var.__bound__
    if bound is None:
        # An unconstrained TypeVar reports ``None`` here. Keep raising
        # AttributeError (as ``None.__name__`` would), but say what is wrong.
        raise AttributeError(
            f"{type_var!r} has no bound class; declare it with 'bound=...'"
        )
    return bound.__name__
28+
29+
2430class DocumentClient :
2531 input_doc : InputSource
2632 doc_configs : DocumentConfigDict
@@ -38,42 +44,49 @@ def __init__(
3844
3945 def parse (
4046 self ,
41- document_type : str ,
42- username : Optional [str ] = None ,
47+ document_class : TypeDocument ,
48+ endpoint_name : Optional [str ] = None ,
49+ account_name : Optional [str ] = None ,
4350 include_words : bool = False ,
4451 close_file : bool = True ,
4552 page_options : Optional [PageOptions ] = None ,
46- ) -> PredictResponse :
53+ ) -> PredictResponse [ TypeDocument ] :
4754 """
4855 Call prediction API on the document and parse the results.
4956
50- :param document_type: Document type to parse
51- :param username: API username, the endpoint owner
57+ :type document_class: DocT
58+ :param endpoint_name: Document type to parse
59+ :param account_name: API username, the endpoint owner
5260 :param include_words: Include all the words of the document in the response
5361 :param close_file: Whether to `close()` the file after parsing it.
5462 Set to `False` if you need to access the file after this operation.
5563 :param page_options: PageOptions object for cutting multipage documents.
5664 """
57- logger .debug ("Parsing document as '%s'" , document_type )
65+ if get_type_var_name (document_class ) != CustomV1 .__name__ :
66+ endpoint_name = get_type_var_name (document_class )
67+ elif endpoint_name is None :
68+ raise RuntimeError ("document_type is required for CustomDocument" )
69+
70+ logger .debug ("Parsing document as '%s'" , endpoint_name )
5871
5972 found = []
6073 for k in self .doc_configs .keys ():
61- if k [1 ] == document_type :
74+ if k [1 ] == endpoint_name :
6275 found .append (k )
6376
6477 if len (found ) == 0 :
65- raise RuntimeError (f"Document type not configured: { document_type } " )
78+ raise RuntimeError (f"Document type not configured: { endpoint_name } " )
6679
67- if username :
68- config_key = (username , document_type )
80+ if account_name :
81+ config_key = (account_name , endpoint_name )
6982 elif len (found ) == 1 :
7083 config_key = found [0 ]
7184 else :
7285 usernames = [k [0 ] for k in found ]
7386 raise RuntimeError (
7487 (
7588 "Duplicate configuration detected.\n "
76- f"You specified a document_type '{ document_type } ' in your custom config.\n "
89+ f"You specified a document_type '{ endpoint_name } ' in your custom config.\n "
7790 "To avoid confusion, please add the 'account_name' attribute to "
7891 f"the parse method, one of { usernames } ."
7992 )
@@ -87,11 +100,18 @@ def parse(
87100 page_options .on_min_pages ,
88101 page_options .page_indexes ,
89102 )
90- return self ._make_request (doc_config , include_words , close_file )
103+ return self ._make_request (document_class , doc_config , include_words , close_file )
91104
92105 def _make_request (
93- self , doc_config : DocumentConfig , include_words : bool , close_file : bool
94- ) -> PredictResponse :
106+ self ,
107+ document_class : TypeDocument ,
108+ doc_config : DocumentConfig ,
109+ include_words : bool ,
110+ close_file : bool ,
111+ ) -> PredictResponse [TypeDocument ]:
112+ if get_type_var_name (document_class ) != doc_config .document_class .__name__ :
113+ raise RuntimeError ("Document class mismatch!" )
114+
95115 response = doc_config .document_class .request (
96116 doc_config .endpoints ,
97117 self .input_doc ,
@@ -106,7 +126,7 @@ def _make_request(
106126 "API %s HTTP error: %s"
107127 % (response .status_code , json .dumps (dict_response ))
108128 )
109- return PredictResponse (
129+ return PredictResponse [ TypeDocument ] (
110130 http_response = dict_response ,
111131 doc_config = doc_config ,
112132 input_source = self .input_doc ,
@@ -143,27 +163,27 @@ def __init__(self, api_key: str = "", raise_on_error: bool = True):
143163
144164 def _init_default_endpoints (self ) -> None :
145165 self ._doc_configs = {
146- (OTS_OWNER , "invoice" ): DocumentConfig (
147- document_type = "invoice " ,
148- constructor = Invoice ,
166+ (OTS_OWNER , InvoiceV3 . __name__ ): DocumentConfig (
167+ document_type = "invoice_v3 " ,
168+ document_class = InvoiceV3 ,
149169 endpoints = [
150170 StandardEndpoint (
151171 url_name = "invoices" , version = "3" , api_key = self .api_key
152172 )
153173 ],
154174 ),
155- (OTS_OWNER , "receipt" ): DocumentConfig (
156- document_type = "receipt " ,
157- constructor = Receipt ,
175+ (OTS_OWNER , ReceiptV3 . __name__ ): DocumentConfig (
176+ document_type = "receipt_v3 " ,
177+ document_class = ReceiptV3 ,
158178 endpoints = [
159179 StandardEndpoint (
160180 url_name = "expense_receipts" , version = "3" , api_key = self .api_key
161181 )
162182 ],
163183 ),
164- (OTS_OWNER , "financial_doc" ): DocumentConfig (
184+ (OTS_OWNER , FinancialDocument . __name__ ): DocumentConfig (
165185 document_type = "financial_doc" ,
166- constructor = FinancialDocument ,
186+ document_class = FinancialDocument ,
167187 endpoints = [
168188 StandardEndpoint (
169189 url_name = "invoices" , version = "3" , api_key = self .api_key
@@ -173,18 +193,18 @@ def _init_default_endpoints(self) -> None:
173193 ),
174194 ],
175195 ),
176- (OTS_OWNER , "passport" ): DocumentConfig (
177- document_type = "passport " ,
178- constructor = Passport ,
196+ (OTS_OWNER , PassportV1 . __name__ ): DocumentConfig (
197+ document_type = "passport_v1 " ,
198+ document_class = PassportV1 ,
179199 endpoints = [
180200 StandardEndpoint (
181201 url_name = "passport" , version = "1" , api_key = self .api_key
182202 )
183203 ],
184204 ),
185- (OTS_OWNER , "bank_check" ): DocumentConfig (
186- document_type = "bank_check " ,
187- constructor = BankCheck ,
205+ (OTS_OWNER , BankCheckV1 . __name__ ): DocumentConfig (
206+ document_type = "bank_check_v1 " ,
207+ document_class = BankCheckV1 ,
188208 endpoints = [
189209 StandardEndpoint (
190210 url_name = "bank_check" , version = "1" , api_key = self .api_key
@@ -198,18 +218,21 @@ def add_endpoint(
198218 account_name : str ,
199219 endpoint_name : str ,
200220 version : str = "1" ,
221+ document_class : Type [Document ] = CustomV1 ,
201222 ) -> "Client" :
202223 """
203224 Add a custom endpoint, created using the Mindee API Builder.
204225
205226 :param endpoint_name: The "API name" field in the "Settings" page of the API Builder
206227 :param account_name: Your organization's username on the API Builder
207228 :param version: If set, locks the version of the model to use.
208- If not set, use the latest version of the model.
229+ If not set, use the latest version of the model.
230+ :param document_class: A document class in which the response will be extracted.
231+ Must inherit from ``mindee.documents.base.Document``.
209232 """
210233 self ._doc_configs [(account_name , endpoint_name )] = DocumentConfig (
211234 document_type = endpoint_name ,
212- constructor = CustomDocument ,
235+ document_class = document_class ,
213236 endpoints = [
214237 CustomEndpoint (
215238 owner = account_name ,
0 commit comments