@@ -756,24 +756,27 @@ def infer_config(args, constructor, trace_patch):
756756 return trace_specs , grouped_mappings , sizeref , color_range
757757
758758
759- def make_figure (args , constructor , trace_patch = {}, layout_patch = {}):
760- apply_default_cascade (args )
761- trace_specs , grouped_mappings , sizeref , color_range = infer_config (
762- args , constructor , trace_patch
763- )
764- grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
765- grouped = args ["data_frame" ].groupby (grouper , sort = False )
759+ def get_orderings (args , grouper , grouped ):
760+ """
761+ `orders` is the user-supplied ordering (with the remaining data-frame-supplied
762+ ordering appended if the column is used for grouping)
763+ `group_names` is the set of groups, ordered by the order above
764+ """
766765 orders = {} if "category_orders" not in args else args ["category_orders" ].copy ()
767766 group_names = []
768767 for group_name in grouped .groups :
769768 if len (grouper ) == 1 :
770769 group_name = (group_name ,)
771770 group_names .append (group_name )
772- for col , val in zip (grouper , group_name ):
773- if col not in orders :
774- orders [col ] = []
775- if val not in orders [col ]:
776- orders [col ].append (val )
771+ for col in grouper :
772+ if col != one_group :
773+ uniques = args ["data_frame" ][col ].unique ()
774+ if col not in orders :
775+ orders [col ] = list (uniques )
776+ else :
777+ for val in uniques :
778+ if val not in orders [col ]:
779+ orders [col ].append (val )
777780
778781 for i , col in reversed (list (enumerate (grouper ))):
779782 if col != one_group :
@@ -782,10 +785,23 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
782785 key = lambda g : orders [col ].index (g [i ]) if g [i ] in orders [col ] else - 1 ,
783786 )
784787
788+ return orders , group_names
789+
790+
791+ def make_figure (args , constructor , trace_patch = {}, layout_patch = {}):
792+ apply_default_cascade (args )
793+ trace_specs , grouped_mappings , sizeref , color_range = infer_config (
794+ args , constructor , trace_patch
795+ )
796+ grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
797+ grouped = args ["data_frame" ].groupby (grouper , sort = False )
798+
799+ orders , sorted_group_names = get_orderings (args , grouper , grouped )
800+
785801 trace_names_by_frame = {}
786802 frames = OrderedDict ()
787803 trendline_rows = []
788- for group_name in group_names :
804+ for group_name in sorted_group_names :
789805 group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
790806 mapping_labels = OrderedDict ()
791807 trace_name_labels = OrderedDict ()
0 commit comments