-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsvarfci.py
2101 lines (1534 loc) · 96.9 KB
/
svarfci.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import numpy as np
from itertools import product, combinations
import os
class SVARFCI():
r"""
This class implements the SVAR-FCI algorithm introduced in:
Malinsky, D. and Spirtes, P. (2018). Causal Structure Learning from Multivariate Time Series in Settings with Unmeasured Confounding. In Le, T. D., Zhang, K., Kıcıman, E., Hyvärinen, A., and Liu, L., editors, Proceedings of 2018 ACM SIGKDD Workshop on Causal Discovery, volume 92 of Proceedings of Machine Learning Research, pages 23–47, London, UK. PMLR.
Our implementation applies several modifications:
1) It assumes the absence of selection variables.
2) It guarantees order-independence by i) using the majority rule to decide whether a given node is in a given separating set and ii) applying a rule to the entire graph and resolving potential conflicts among the proposed orientations by means of the conflict mark 'x' before modifying the graph.
3) It allows for the following conclusion: If X^i_{t-\tau} and X^j_t for \tau > 0 are not m-separated by any subset of D-Sep(X^j_t, X^i_{t-\tau}, \mathcal{M}(\mathcal{G})) then these variables are adjacent in \mathcal{M}(\mathcal{G}). In particular, this conclusion does not require that X^i_{t-\tau} and X^j_t are moreover not m-separated by any subset of D-Sep(X^i_{t-\tau}, X^j_t, \mathcal{M}(\mathcal{G}))
4) Several control parameters apply further modifications, see below.
Parameters passed to the constructor:
- dataframe:
Tigramite dataframe object that contains the time series dataset \bold{X}
- cond_ind_test:
A conditional independence test object that specifies which conditional independence test CI is to be used
Parameters passed to self.run_svarfci():
- tau_max:
The maximum considered time lag tau_max
- pc_alpha:
The significance level \alpha of conditional independence tests
- max_cond_px:
Consider a pair of variables (X^i_{t-\tau}, X^j_t) with \tau > 0. In the first removal phase (here this is self._run_pc_removal_phase()), the algorithm does not test for conditional independence given subsets of X^i_{t-\tau} of cardinality higher than max_cond_px. In the second removal phase (here this is self._run_dsep_removal_phase()), the algorithm does not test for conditional independence given subsets of pds_t(X^i_{t-\tau}, X^j_t) of cardinality higher than max_cond_px.
- max_p_global:
Restricts all conditional independence tests to conditioning sets with cardinality smaller or equal to max_p_global
- max_p_dsep:
Restricts all conditional independence tests in the second removal phase (here this is self._run_dsep_removal_phase()) to conditioning sets with cardinality smaller or equal to max_p_dsep
- max_q_global:
For each ordered pair (X^i_{t-\tau}, X^j_t) of adjacent variables and for each cardinality of the conditioning sets test at most max_q_global many conditioning sets (when summing over all tested cardinalities more than max_q_global tests may be made)
- max_pds_set:
In the second removal phase (here this is self._run_dsep_removal_phase()), the algorithm tests for conditional independence given subsets of the pds_t sets defined in the above reference. If for a given link the set pds_t(X^j_t, X^i_{t-\tau}) has more than max_pds_set many elements (or, if the link is also tested in the opposite directed, if pds_t(X^i_{t-\tau}, X^j_t) has more than max_pds_set elements), this link is not tested.
- fix_all_edges_before_final_orientation:
When one of the four previous parameters is not np.inf, the edge removals may terminate before we can be sure that all remaining edges are indeed part of the true PAG. However, soundness of the FCI orientation rules requires that they be applied only once the correct skeleton has been found. Therefore, the rules are only applied to those edges for which we are sure that they are part of the PAG. This can lead to quite uninformative results. If fix_all_edges_before_final_orientation is True, this precaution is overruled and the orientation rules are nevertheless applied to all edges.
- verbosity:
Controls the verbose output self.run_svarfci() and the function it calls.
Return value of self.run_svarfci():
The estimated graph in form of a link matrix. This is a numpy array of shape (self.N, self.N, self.tau_max + 1), where the entry array[i, j, \tau] is a string that visualizes the estimated link from X^i_{i-\tau} to X^j_t. For example, if array[0, 2, 1] = 'o->', then the estimated graph contains the link X^i_{t-1} o-> X^j_t. This numpy array is also saved as instance attribute self.graph. Note that self.N is the number of observed time series and self.tau_max the maximal considered time lag.
A note on middle marks:
In order to distinguish edges that are in the PAG for sure from edges that may not be in the PAG, we use the notion of middle marks that we introduced for LPCMCI. This becomes useful for the precaution discussed in the explanation of the parameter 'fix_all_edges_before_final_orientation', see above. In particular, we use the middle marks '?' and '' (empty). For convenience (to have strings of the same lengths) we here internally denote the empty middle mark by '-'. For post-processing purposes all middle marks are nevertheless set to the empty middle mark (here '-') in line 99, but if verbosity >= 1 a graph with the middle marks will be printed out before.
A note on wildcards:
The middle mark wildcard \ast and the edge mark wildcard are here represented as *, the edge mark wildcard \star as +
"""
def __init__(self, dataframe, cond_ind_test):
"""Class constructor. Store:
i) data
ii) conditional independence test object
iii) some instance attributes"""
# Save the time series data that the algorithm operates on
self.dataframe = dataframe
# Set the conditional independence test to be used
self.cond_ind_test = cond_ind_test
self.cond_ind_test.set_dataframe(self.dataframe)
# Store the shape of the data in the T and N variables
self.T, self.N = self.dataframe.values.shape
def run_svarfci(self,
    tau_max = 1,
    pc_alpha = 0.05,
    max_cond_px = 0,
    max_p_global = np.inf,
    max_p_dsep = np.inf,
    max_q_global = np.inf,
    max_pds_set = np.inf,
    fix_all_edges_before_final_orientation = True,
    verbosity = 0):
    """Run the SVAR-FCI algorithm on the dataset and with the conditional independence test
    passed to the class constructor, using the options passed to this function.

    Returns the estimated PAG as a link matrix of shape (self.N, self.N, tau_max + 1); the
    same array is also stored as self.graph. See the class docstring for the meaning of all
    parameters and of the link-string entries.
    """

    # Step 0: Initializations (store all arguments as instance attributes and reset all memory variables)
    self._initialize(tau_max, pc_alpha, max_cond_px, max_p_global, max_p_dsep, max_q_global, max_pds_set, fix_all_edges_before_final_orientation, verbosity)

    # Step 1: PC removal phase (skeleton search conditioning on non-future adjacencies)
    self._run_pc_removal_phase()

    # Step 2: D-Sep removal phase (including the preliminary collider orientation phase
    # needed to compute pds_t sets)
    self._run_dsep_removal_phase()

    # Step 3: FCI orientation phase. If requested, first declare all remaining edges as
    # definitely part of the PAG (middle mark '-') so that the orientation rules apply to
    # all of them — see the class docstring on 'fix_all_edges_before_final_orientation'
    if self.fix_all_edges_before_final_orientation:
        self._fix_all_edges()

    self._run_fci_orientation_phase()

    # Post processing: report bookkeeping, normalize all middle marks to '-' and convert
    # the internal dictionaries to matrix form
    if self.verbosity >= 1:
        print("Ambiguous triples", self.ambiguous_triples)
        print("Max pds set: {}\n".format(self.max_pds_set_found))

    self._fix_all_edges()
    self.graph = self._dict2graph()
    self.val_min_matrix = self._dict_to_matrix(self.val_min, self.tau_max, self.N, default = 0)
    self.cardinality_matrix = self._dict_to_matrix(self.max_cardinality, self.tau_max, self.N, default = 0)

    # Return the estimated graph
    return self.graph
def _initialize(self,
tau_max,
pc_alpha,
max_cond_px,
max_p_global,
max_p_dsep,
max_q_global,
max_pds_set,
fix_all_edges_before_final_orientation,
verbosity):
"""Function for
i) saving the arguments passed to self.run_svarfci() as instance attributes
ii) initializing various memory variables for storing the current graph, sepsets etc.
"""
# Save the arguments passed to self.run_svarfci()
self.tau_max = tau_max
self.pc_alpha = pc_alpha
self.max_cond_px = max_cond_px
self.max_p_global = max_p_global
self.max_p_dsep = max_p_dsep
self.max_q_global = max_q_global
self.max_pds_set = max_pds_set
self.fix_all_edges_before_final_orientation = fix_all_edges_before_final_orientation
self.verbosity = verbosity
# Initialize the nested dictionary for storing the current graph.
# Syntax: self.graph_dict[j][(i, -tau)] gives the string representing the link from X^i_{t-tau} to X^j_t
self.graph_dict = {}
for j in range(self.N):
self.graph_dict[j] = {(i, 0): "o?o" for i in range(self.N) if j != i}
self.graph_dict[j].update({(i, -tau): "o?>" for i in range(self.N) for tau in range(1, self.tau_max + 1)})
# Initialize the nested dictionary for storing separating sets
# Syntax: self.sepsets[j][(i, -tau)] stores separating sets of X^i_{t-tau} to X^j_t. For tau = 0, i < j.
self.sepsets = {j: {(i, -tau): set() for i in range(self.N) for tau in range(self.tau_max + 1) if (tau > 0 or i < j)} for j in range(self.N)}
# Initialize dictionaries for storing known ancestorships, non-ancestorships, and ambiguous ancestorships
# Syntax: self.def_ancs[j] contains the set of all known ancestors of X^j_t. Equivalently for the others
self.def_ancs = {j: set() for j in range(self.N)}
self.def_non_ancs = {j: set() for j in range(self.N)}
self.ambiguous_ancestorships = {j: set() for j in range(self.N)}
# Initialize nested dictionaries for saving the minimum test statistic among all conditional independence tests of a given pair of variables, the maximum p-values, as well as the maximal cardinality of the known separating sets.
# Syntax: As for self.sepsets
self.val_min = {j: {(i, -tau): float("inf") for i in range(self.N) for tau in
range(self.tau_max + 1) if (tau > 0 or i < j)} for j in range(self.N)}
self.pval_max = {j: {(i, -tau): 0 for i in range(self.N) for tau in
range(self.tau_max + 1) if (tau > 0 or i < j)} for j in range(self.N)}
self.max_cardinality = {j: {(i, -tau): 0 for i in range(self.N) for tau in
range(self.tau_max + 1) if (tau > 0 or i < j)} for j in range(self.N)}
# Initialize a nested dictionary for caching pds-sets
# Syntax: As for self.sepsets
self._pds_t = {(j, -tau_j): {} for j in range(self.N) for tau_j in range(self.tau_max + 1)}
# Initialize a set for memorizing ambiguous triples
self.ambiguous_triples = set()
# Initialize a variable for remembering the maximal cardinality among all calculated pds-sets
self.max_pds_set_found = -1
################################################################################################
# Only relevant for use with oracle CI
self._oracle = False
################################################################################################
# Return
return True
def _run_pc_removal_phase(self):
    """Run the first removal phase of the FCI algorithm adapted to stationary time series.
    This is essentially the skeleton phase of the PC algorithm: edges are removed whenever a
    separating set is found among the current non-future adjacencies of the pair."""

    # Verbose output
    if self.verbosity >= 1:
        print("\n=======================================================")
        print("=======================================================")
        print("Starting preliminary removal phase")

    # Iterate until convergence
    # p_pc is the cardinality of the conditioning set
    p_pc = 0
    while True:

        ##########################################################################################################
        ### Run the next removal iteration #######################################################################

        # Verbose output
        if self.verbosity >= 1:
            if p_pc == 0:
                print("\nStarting test phase\n")
            print("p = {}".format(p_pc))

        # Variable to check for convergence
        has_converged = True

        # Variable for keeping track of edges marked for removal; removals are deferred to the
        # end of the iteration to guarantee order independence
        to_remove = {j: {} for j in range(self.N)}

        # Iterate through all links
        for (i, j, lag_i) in product(range(self.N), range(self.N), range(-self.tau_max, 1)):

            # Decode the triple (i, j, lag_i) into pairs of variables (X, Y)
            X = (i, lag_i)
            Y = (j, 0)

            ######################################################################################################
            ### Exclusion of links ###############################################################################

            # Exclude the current link if ...
            # ... X = Y
            if lag_i == 0 and i == j:
                continue

            # ... X > Y (so, in fact, we don't distinguish between both directions of the same edge)
            if self._is_smaller(Y, X):
                continue

            # Get the current link from X to Y
            link = self._get_link(X, Y)

            # Also exclude the current link if ...
            # ... X and Y are not adjacent anymore
            if link == "":
                continue

            ######################################################################################################
            ### Preparation of PC search sets ####################################################################

            # Search for separating sets in the non-future adjacencies of Y, without X and Y themselves
            S_search_YX = self._get_non_future_adj([Y]).difference({X, Y})

            # Also search in the non-future adjacencies of X (again without X and Y themselves):
            # always if X and Y are contemporaneous, else only if allowed by self.max_cond_px
            test_X = True if (lag_i == 0 or (self.max_cond_px > 0 and self.max_cond_px >= p_pc)) else False
            if test_X:
                S_search_XY = self._get_non_future_adj([X]).difference({X, Y})

            ######################################################################################################
            ### Check whether the link needs testing #############################################################

            # If there are less than p_pc elements in the search sets, the link does not need further testing
            if len(S_search_YX) < p_pc and (not test_X or len(S_search_XY) < p_pc):
                continue

            # Force-quit the while loop when p_pc exceeds the specified limits
            if p_pc > self.max_p_global:
                continue

            # This link does need testing. Therefore, the algorithm has not converged yet
            has_converged = False

            ######################################################################################################
            ### Tests for conditional independence ###############################################################

            # If self.max_q_global is finite, the below for loop may be broken earlier. To still guarantee order independence, the set from which the potential separating sets are created is ordered in an order independent way. Here, the elements of S_search_YX are ordered according to their minimal test statistic with Y
            if not np.isinf(self.max_q_global):
                S_search_YX = self._sort_search_set(S_search_YX, Y)

            # q_count counts the number of conditional independence tests made for subsets of S_search_YX
            q_count = 0

            # Run through all cardinality p_pc subsets of S_search_YX
            for Z in combinations(S_search_YX, p_pc):

                # Stop testing if the number of tests exceeds the bound specified by self.max_q_global
                q_count = q_count + 1
                if q_count > self.max_q_global:
                    break

                # Test conditional independence of X and Y given Z. Correspondingly update self.val_min, self.pval_max, and self.max_cardinality
                val, pval = self.cond_ind_test.run_test(X = [X], Y = [Y], Z = list(Z), tau_max = self.tau_max)

                if self.verbosity >= 2:
                    print(" %s _|_ %s | S_pc = %s: val = %.2f / pval = % .4f" %
                        (X, Y, ' '.join([str(z) for z in list(Z)]), val, pval))

                self._update_val_min(X, Y, val)
                self._update_pval_max(X, Y, pval)
                self._update_cardinality(X, Y, len(Z))

                # Check whether the test result was significant
                if pval > self.pc_alpha:

                    # Mark the edge from X to Y for removal, save Z as separating set
                    to_remove[Y[0]][X] = True
                    self._save_sepset(X, Y, (frozenset(Z), ""))

                    # Verbose output
                    if self.verbosity >= 1:
                        print("({},{:2}) {:11} {} given {}".format(X[0], X[1], "independent", Y, Z))

                    # Break the for loop
                    break

            # Run through all cardinality p_pc subsets of S_search_XY (same procedure as above)
            if test_X:

                if not np.isinf(self.max_q_global):
                    S_search_XY = self._sort_search_set(S_search_XY, X)

                q_count = 0
                for Z in combinations(S_search_XY, p_pc):

                    q_count = q_count + 1
                    if q_count > self.max_q_global:
                        break

                    val, pval = self.cond_ind_test.run_test(X = [X], Y = [Y], Z = list(Z), tau_max = self.tau_max)

                    if self.verbosity >= 2:
                        print(" %s _|_ %s | S_pc = %s: val = %.2f / pval = % .4f" %
                            (X, Y, ' '.join([str(z) for z in list(Z)]), val, pval))

                    self._update_val_min(X, Y, val)
                    self._update_pval_max(X, Y, pval)
                    self._update_cardinality(X, Y, len(Z))

                    if pval > self.pc_alpha:
                        to_remove[Y[0]][X] = True
                        self._save_sepset(X, Y, (frozenset(Z), ""))

                        if self.verbosity >= 1:
                            print("({},{:2}) {:11} {} given {}".format(X[0], X[1], "independent", Y, Z))

                        break

        # end for (i, j, lag_i) in product(range(self.N), range(self.N), range(-self.tau_max, 1))

        ##########################################################################################################
        ### Remove edges marked for removal in to_remove #########################################################

        # Remove edges
        for j in range(self.N):
            for (i, lag_i) in to_remove[j].keys():
                self._write_link((i, lag_i), (j, 0), "", verbosity = self.verbosity)

        # Verbose output
        if self.verbosity >= 1:
            print("\nTest phase complete")

        ##########################################################################################################
        ### Check for convergence ################################################################################

        if has_converged:
            # If no link needed testing, this algorithm has converged. Therefore, break the while loop
            break
        else:
            # At least one link needed testing, this algorithm has not yet converged. Therefore, increase p_pc
            p_pc = p_pc + 1

    # end while True

    # Verbose output
    if self.verbosity >= 1:
        print("\nPreliminary removal phase complete")
        print("\nGraph:\n--------------------------------")
        self._print_graph_dict()
        print("--------------------------------")

    # Return
    return True
def _run_dsep_removal_phase(self):
    """Run the second removal phase of the FCI algorithm, including the preliminary collider
    orientation that is necessary for determining pds-sets. Here, separating sets are searched
    among subsets of the pds_t sets rather than among current adjacencies."""

    # Verbose output
    if self.verbosity >= 1:
        print("\n=======================================================")
        print("=======================================================")
        print("Starting final removal phase")

    # Make the preliminary orientations that are necessary for determining pds_t sets
    self._run_orientation_phase(rule_list = [["R-00-d"]], voting = "Majority-Preliminary")

    # Remember all edges that have not been fully tested due to self.max_pds_set, self.max_q_global or self.max_p_global
    self._cannot_fix = set()

    # Iterate until convergence
    # p_pc is the cardinality of the conditioning set
    p_pc = 0
    while True:

        ##########################################################################################################
        ### Run the next removal iteration #######################################################################

        # Verbose output
        if self.verbosity >= 1:
            if p_pc == 0:
                print("\nStarting test phase\n")
            print("p = {}".format(p_pc))

        # Variable to check for convergence
        has_converged = True

        # Variable for keeping track of edges marked for removal; removals are deferred to the
        # end of the iteration to guarantee order independence
        to_remove = {j: {} for j in range(self.N)}

        # Iterate through all links
        for (i, j, lag_i) in product(range(self.N), range(self.N), range(-self.tau_max, 1)):

            # Decode the triple (i, j, lag_i) into pairs of variables (X, Y)
            X = (i, lag_i)
            Y = (j, 0)

            ######################################################################################################
            ### Exclusion of links ###############################################################################

            # Exclude the current link if ...
            # ... X = Y
            if lag_i == 0 and i == j:
                continue

            # ... X > Y
            if self._is_smaller(Y, X):
                continue

            # Get the current link
            link = self._get_link(X, Y)

            # Also exclude the current link if ...
            # ... X and Y are not adjacent anymore
            if link == "":
                continue

            # ... X and Y are already known to be adjacent in the true MAG (middle mark '-')
            if link[1] == "-":
                continue

            ######################################################################################################
            ### Preparation of PC search sets ####################################################################

            # Verbose output
            if self.verbosity >= 2:
                print("_get_pds_t ")

            # Search for separating sets in pds_t(Y, X)
            S_search_YX = self._get_pds_t(Y, X)

            # Search for separating sets in pds_t(X, Y) always if X and Y are contemporaneous or if specified by self.max_cond_px
            test_X = True if (lag_i == 0 or (self.max_cond_px > 0 and self.max_cond_px >= p_pc)) else False
            if test_X:
                S_search_XY = self._get_pds_t(X, Y)

            # If the pds_t sets exceed the specified bounds, do not test this link. Remember that the link has not been fully tested
            if len(S_search_YX) > self.max_pds_set or (test_X and len(S_search_XY) > self.max_pds_set):
                self._cannot_fix.add((X, Y))
                continue

            ######################################################################################################
            ### Check whether the link needs testing #############################################################

            # If there are less than p_pc elements in the search set(s), the link does not need further testing. X and Y are adjacent in the true MAG, unless the link has not been fully tested
            if len(S_search_YX) < p_pc and (not test_X or len(S_search_XY) < p_pc):
                if (X, Y) not in self._cannot_fix:
                    # Mark the edge as definitely in the PAG by setting its middle mark to '-'
                    self._write_link(X, Y, link[0] + "-" + link[2], verbosity = self.verbosity)
                continue

            # Force-quit the while loop when p_pc exceeds the specified limits
            if p_pc > self.max_p_global or p_pc > self.max_p_dsep:
                continue

            # Since this link does need testing, the algorithm has not converged yet
            has_converged = False

            ######################################################################################################
            ### Tests for conditional independence ###############################################################

            # Verbose output
            if self.verbosity >= 1:
                print("for S_pc in combinations(S_search_YX, p_pc)")

            # If self.max_q_global is finite, the below for loop may be broken earlier. To still guarantee order independence, the set from which the potential separating sets are created is ordered in an order independent way. Here, the elements of S_search_YX are ordered according to their minimal test statistic with Y
            if not np.isinf(self.max_q_global):
                S_search_YX = self._sort_search_set(S_search_YX, Y)

            # q_count counts the number of conditional independence tests made for subsets of S_search_YX
            q_count = 0

            # Run through all cardinality p_pc subsets of S_search_YX
            for Z in combinations(S_search_YX, p_pc):

                # Stop testing if the number of tests exceeds the bound specified by self.max_q_global. Remember that the link has not been fully tested
                q_count = q_count + 1
                if q_count > self.max_q_global:
                    self._cannot_fix.add((X, Y))
                    break

                # Test conditional independence of X and Y given Z. Correspondingly update self.val_min, self.pval_max, and self.max_cardinality
                val, pval = self.cond_ind_test.run_test(X = [X], Y = [Y], Z = list(Z), tau_max = self.tau_max)

                if self.verbosity >= 2:
                    print(" %s _|_ %s | S_pc = %s: val = %.2f / pval = % .4f" %
                        (X, Y, ' '.join([str(z) for z in list(Z)]), val, pval))

                self._update_val_min(X, Y, val)
                self._update_pval_max(X, Y, pval)
                self._update_cardinality(X, Y, len(Z))

                # Check whether the test result was significant
                if pval > self.pc_alpha:

                    # Mark the edge from X to Y for removal and save sepset
                    to_remove[Y[0]][X] = True
                    self._save_sepset(X, Y, (frozenset(Z), ""))

                    # Verbose output
                    if self.verbosity >= 1:
                        print("({},{:2}) {:11} {} given {}".format(X[0], X[1], "independent", Y, Z))

                    # Break the for loop
                    break

            # Run through all cardinality p_pc subsets of S_search_XY (same procedure as above)
            if test_X:

                if self.verbosity >= 1:
                    print("for S_pc in combinations(S_search_XY, p_pc)")

                if not np.isinf(self.max_q_global):
                    S_search_XY = self._sort_search_set(S_search_XY, X)

                q_count = 0
                for Z in combinations(S_search_XY, p_pc):

                    q_count = q_count + 1
                    if q_count > self.max_q_global:
                        self._cannot_fix.add((X, Y))
                        break

                    val, pval = self.cond_ind_test.run_test(X = [X], Y = [Y], Z = list(Z), tau_max = self.tau_max)

                    if self.verbosity >= 2:
                        print(" %s _|_ %s | S_pc = %s: val = %.2f / pval = % .4f" %
                            (X, Y, ' '.join([str(z) for z in list(Z)]), val, pval))

                    # Update val_min and pval_max
                    self._update_val_min(X, Y, val)
                    self._update_pval_max(X, Y, pval)
                    self._update_cardinality(X, Y, len(Z))

                    if pval > self.pc_alpha:
                        to_remove[Y[0]][X] = True
                        self._save_sepset(X, Y, (frozenset(Z), ""))

                        if self.verbosity >= 1:
                            print("({},{:2}) {:11} {} given {}".format(X[0], X[1], "independent", Y, Z))

                        break

        # end for (i, j, lag_i) in product(range(self.N), range(self.N), range(-(tau_max + 1), 1))

        ##########################################################################################################
        ### Remove edges marked for removal in to_remove #########################################################

        # Remove edges
        for j in range(self.N):
            for (i, lag_i) in to_remove[j].keys():
                self._write_link((i, lag_i), (j, 0), "", verbosity = self.verbosity)

        # Verbose output
        if self.verbosity >= 1:
            print("\nTest phase complete")

        ##########################################################################################################
        ### Check for convergence ################################################################################

        if has_converged:
            # If no link needed testing, this algorithm has converged. Therefore, break the while loop
            break
        else:
            # At least one link needed testing, this algorithm has not yet converged. Therefore, increase p_pc
            p_pc = p_pc + 1

    # end while True

    # Undo all preliminary collider orientations; the final orientations are determined in the
    # subsequent FCI orientation phase
    self._unorient_all_edges()
    self.def_non_ancs = {j: set() for j in range(self.N)}

    # Verbose output
    if self.verbosity >= 1:
        print("\nFinal removal phase complete")
        print("\nGraph:\n--------------------------------")
        self._print_graph_dict()
        print("--------------------------------")

    # Return
    return True
def _run_fci_orientation_phase(self):
    """Run the final orientation phase of the FCI algorithm: first orient colliders, then
    exhaustively apply the remaining FCI orientation rules."""

    # Verbose output
    if self.verbosity >= 1:
        print("\n=======================================================")
        print("=======================================================")
        print("Starting FCI orientation phase")

    # Orient colliders
    self._run_orientation_phase(rule_list = [["R-00-d"]], voting = "Majority-Final")

    # Exhaustively apply the other relevant orientation rules. Rules 5, 6 and 7 are not relevant because by assumption there are no selection variables
    self._run_orientation_phase(rule_list = [["R-01"], ["R-02"], ["R-03"], ["R-04"], ["R-08"], ["R-09"], ["R-10"]], voting = "Majority-Final")

    # Verbose output
    if self.verbosity >= 1:
        print("\nFCI orientation phase complete")
        print("\nFinal graph:\n--------------------------------")
        print("--------------------------------")
        self._print_graph_dict()
        print("--------------------------------")
        print("--------------------------------\n")

    # Return
    return True
########################################################################################################################
########################################################################################################################
########################################################################################################################
def _run_orientation_phase(self, rule_list, voting):
    """Exhaustively apply the orientation rules given by rule_list.

    The rules are organized in priority levels (one inner list per level). Whenever the
    graph changes, application restarts at the first level. The voting argument selects
    the criterion by which it is decided whether B is in the separating set of A and C,
    where A - B - C is an unshielded triple. Returns True on convergence.
    """
    # Announce the phase
    if self.verbosity >= 1:
        print("\nStarting orientation phase")
        print("with rule list: ", rule_list)

    # Work through the priority levels of rule_list
    level = 0
    while level < len(rule_list):
        # Some rules require an up-to-date self._graph_full_dict, so rebuild it
        # whenever the loop (re)starts at the first priority level
        if level == 0:
            self._initialize_full_graph()

        ###########################################################################################################
        ### Rule application ######################################################################################
        # Apply every rule of the current level and stage the marked orientations
        marked = []
        for rule in rule_list[level]:
            if self.verbosity >= 1:
                print("\n{}:".format(rule))
            found = self._apply_rule(rule, voting)
            if self.verbosity >= 1:
                for ((i, j, lag_i), new_link) in set(found):
                    print("{:10} ({},{:2}) {:3} ({},{:2}) ==> ({},{:2}) {:3} ({},{:2}) ".format("Marked:", i, lag_i, self._get_link((i, lag_i), (j, 0)), j, 0,i, lag_i, new_link, j, 0))
                if len(found) == 0:
                    print("Found nothing")
            marked.extend(found)

        ###########################################################################################################
        ### Aggregation of marked orientations ####################################################################
        # Translate the staged orientations into new (non-)ancestorship statements
        new_ancs = {j: set() for j in range(self.N)}
        new_non_ancs = {j: set() for j in range(self.N)}
        for ((i, j, lag_i), new_link) in marked:
            old_link = self._get_link((i, lag_i), (j, 0))
            # A variable in the past must never be marked as an ancestor of a later one
            assert not (lag_i > 0 and new_link[2] == "-")
            # New ancestral relation of (i, lag_i) to (j, 0)
            if new_link[0] == "-" and old_link[0] != "-":
                new_ancs[j].add((i, lag_i))
            elif new_link[0] == "<" and old_link[0] != "<":
                new_non_ancs[j].add((i, lag_i))
            # New ancestral relation of (j, 0) to (i, 0), relevant only for contemporaneous links
            if lag_i == 0:
                if new_link[2] == "-" and old_link[2] != "-":
                    new_ancs[i].add((j, 0))
                elif new_link[2] == ">" and old_link[2] != ">":
                    new_non_ancs[i].add((j, 0))

        ###########################################################################################################
        ### Update ancestral information and determine next step ##################################################
        # Write the new information to the graph (includes conflict resolution).
        # If anything useful was learned, restart at the first level, else move on.
        if self._apply_new_ancestral_information(new_non_ancs, new_ancs):
            level = 0
        else:
            level = level + 1

    # The algorithm has converged
    if self.verbosity >= 1:
        print("\nOrientation phase complete")
    return True
def _get_pds_t(self, A, B):
    """Return pds_t(A, B) according to the current graph, using a memo cache."""
    # Unpack the nodes; at least one of them must be at lag 0
    var_A, lag_A = A
    var_B, lag_B = B
    assert lag_A == 0 or lag_B == 0

    # Serve from the cache when possible
    memo = self._pds_t[A].get(B)
    if memo is not None:
        return memo

    # Otherwise run a breadth-first search over (node, predecessor) pairs on the
    # current graph, restricted to lags in [-tau_max, 0]
    visited = set()
    frontier = {((var, lag + lag_A), A)
                for (var, lag) in self.graph_full_dict[var_A].keys()
                if -self.tau_max <= lag + lag_A <= 0}
    while frontier:
        next_frontier = set()
        for (node, prev) in frontier:
            visited.add((node, prev))
            for (var, lag) in self.graph_full_dict[node[0]]:
                cand = (node[0], 0)  # placeholder, immediately overwritten below
                cand = (var, lag + node[1])
                # Stay inside the observed time window
                if cand[1] < -self.tau_max or cand[1] > 0:
                    continue
                # Do not revisit the same directed step and do not step back
                if (cand, node) in visited or cand == prev:
                    continue
                # Path condition: for an unshielded step the triple must not carry a circle mark at node
                if self._get_link(cand, prev) == "" and (self._get_link(prev, node)[2] == "o" or self._get_link(cand, node)[2] == "o"):
                    continue
                next_frontier.add((cand, node))
        frontier = next_frontier

    # Collect all reached nodes except the endpoints, cache and return
    res = {node for (node, _) in visited if node != A and node != B}
    self.max_pds_set_found = max(self.max_pds_set_found, len(res))
    self._pds_t[A][B] = res
    return res
def _unorient_all_edges(self):
    """Remove all orientations, except the non-ancestorships implied by time order."""
    for j in range(self.N):
        for parent in self.graph_dict[j].keys():
            (i, lag_i) = parent
            link = self._get_link((i, lag_i), (j, 0))
            if link:
                # Contemporaneous links become fully unoriented; lagged links keep the
                # arrowhead at (j, 0) implied by time order
                right_mark = "o" if lag_i == 0 else ">"
                self.graph_dict[j][parent] = "o" + link[1] + right_mark
def _fix_all_edges(self):
    """Set the middle mark of all links to '-'."""
    for j in range(self.N):
        for parent in self.graph_dict[j].keys():
            (i, lag_i) = parent
            link = self._get_link((i, lag_i), (j, 0))
            if link:
                # Keep both end marks, replace only the middle mark
                self.graph_dict[j][parent] = link[0] + "-" + link[2]
def _apply_new_ancestral_information(self, new_non_ancs, new_ancs):
    """Apply the new ancestorships and non-ancestorships specified by new_non_ancs and
    new_ancs to the current graph.

    Both arguments map a variable index j to a set of nodes (i, lag_i) that are newly
    marked as (non-)ancestors of (j, 0); either may be None, which is treated as empty.
    Conflicting statements are resolved by marking the ancestorship as ambiguous ('x'
    as the left link mark). Returns True if any circle mark was turned into a head or
    tail, else False.
    """
    #######################################################################################################
    ### Preprocessing #####################################################################################
    # Memory variables for the statements that survive conflict resolution
    add_to_def_non_ancs = {j: set() for j in range(self.N)}
    add_to_def_ancs = {j: set() for j in range(self.N)}
    add_to_ambiguous_ancestorships = {j: set() for j in range(self.N)}
    put_head_or_tail = False
    # Default values
    if new_non_ancs is None:
        new_non_ancs = {j: set() for j in range(self.N)}
    if new_ancs is None:
        new_ancs = {j: set() for j in range(self.N)}
    # Marking A as ancestor of B implies that B is marked as a non-ancestor of A. This is only non-trivial for A before B
    for j in range(self.N):
        for (i, lag_i) in new_ancs[j]:
            if lag_i == 0:
                new_non_ancs[i].add((j, 0))
    #######################################################################################################
    ### Conflict resolution ###############################################################################
    # Iterate through new_non_ancs
    for j in range(self.N):
        for (i, lag_i) in new_non_ancs[j]:
            # X = (i, lag_i), Y = (j, 0)
            # X is marked as non-ancestor for Y
            # Conflict resolution
            if (i, lag_i) in self.ambiguous_ancestorships[j]:
                # There is a conflict, since it is already marked as ambiguous whether X is an ancestor of Y
                if self.verbosity >= 1:
                    print("{:10} ({}, {:2}) marked as non-anc of {} but saved as ambiguous".format("Conflict:", i, lag_i, (j, 0)))
            elif (i, lag_i) in self.def_ancs[j]:
                # There is a conflict, since X is already marked as ancestor of Y
                add_to_ambiguous_ancestorships[j].add((i, lag_i))
                if self.verbosity >= 1:
                    print("{:10} ({}, {:2}) marked as non-anc of {} but saved as anc".format("Conflict:", i, lag_i, (j, 0)))
            elif (i, lag_i) in new_ancs[j]:
                # There is a conflict, since X is also marked as a new ancestor of Y
                add_to_ambiguous_ancestorships[j].add((i, lag_i))
                if self.verbosity >= 1:
                    print("{:10} ({}, {:2}) marked as both anc- and non-anc of {}".format("Conflict:", i, lag_i, (j, 0)))
            else:
                # There is no conflict
                add_to_def_non_ancs[j].add((i, lag_i))
    # Iterate through new_ancs
    for j in range(self.N):
        for (i, lag_i) in new_ancs[j]:
            # X = (i, lag_i), Y = (j, 0)
            # X is marked as ancestor for Y
            # Conflict resolution
            if (i, lag_i) in self.ambiguous_ancestorships[j]:
                # There is a conflict, since it is already marked as ambiguous whether X is an ancestor of Y
                if self.verbosity >= 1:
                    print("{:10} ({}, {:2}) marked as anc of {} but saved as ambiguous".format("Conflict:", i, lag_i, (j, 0)))
            elif lag_i == 0 and (j, 0) in self.ambiguous_ancestorships[i]:
                # There is a conflict, since X and Y are contemporaneous and it is already marked as ambiguous whether Y is an ancestor of X
                # Note: This is required here, because X being an ancestor of Y implies that Y is not an ancestor of X. This ambiguity cannot exist when X is before Y
                if self.verbosity >= 1:
                    print("{:10} ({}, {:2}) marked as anc of {} but saved as ambiguous".format("Conflict:", i, lag_i, (j, 0)))
            elif (i, lag_i) in self.def_non_ancs[j]:
                # There is a conflict, since X is already marked as non-ancestor of Y
                add_to_ambiguous_ancestorships[j].add((i, lag_i))
                if self.verbosity >= 1:
                    print("{:10} ({}, {:2}) marked as anc of {} but saved as non-anc".format("Conflict:", i, lag_i, (j, 0)))
            elif (i, lag_i) in new_non_ancs[j]:
                # There is a conflict, since X is also marked as a new non-ancestor of Y
                add_to_ambiguous_ancestorships[j].add((i, lag_i))
                if self.verbosity >= 1:
                    print("{:10} ({}, {:2}) marked as both anc- and non-anc of {}".format("Conflict:", i, lag_i, (j, 0)))
            else:
                # There is no conflict
                add_to_def_ancs[j].add((i, lag_i))
    #######################################################################################################
    #######################################################################################################
    ### Apply the ambiguous information ###################################################################
    for j in range(self.N):
        for (i, lag_i) in add_to_ambiguous_ancestorships[j]:
            old_link = self._get_link((i, lag_i), (j, 0))
            if len(old_link) > 0 and old_link[0] != "x":
                new_link = "x" + old_link[1] + old_link[2]
                self._write_link((i, lag_i), (j, 0), new_link, verbosity = self.verbosity)
            if self.verbosity >= 1:
                if (i, lag_i) in self.def_ancs[j]:
                    print("{:10} Removing ({}, {:2}) as anc of {}".format("Update:", i, lag_i, (j, 0)))
                if (i, lag_i) in self.def_non_ancs[j]:
                    print("{:10} Removing ({}, {:2}) as non-anc of {}".format("Update:", i, lag_i, (j, 0)))
            self.def_ancs[j].discard((i, lag_i))
            self.def_non_ancs[j].discard((i, lag_i))
            if lag_i == 0:
                if self.verbosity >= 1 and (j, 0) in self.def_ancs[i]:
                    # BUGFIX: the format call previously passed (i, lag_i, (j, 0)) to two
                    # placeholders, so it printed the indices i and lag_i instead of the
                    # nodes involved. The code below discards (j, 0) from def_ancs[i].
                    print("{:10} Removing {} as anc of {}".format("Update:", (j, 0), (i, 0)))
                self.def_ancs[i].discard((j, 0))
                # NOTE(review): it is unclear whether (j, 0) should also be discarded from
                # self.def_non_ancs[i] here — left unchanged to preserve behavior
                # self.def_non_ancs[i].discard((j, 0))
            if self.verbosity >= 1 and (i, lag_i) not in self.ambiguous_ancestorships[j]:
                print("{:10} Marking ancestorship of ({}, {:2}) to {} as ambiguous".format("Update:", i, lag_i, (j, 0)))
            self.ambiguous_ancestorships[j].add((i, lag_i))
    #######################################################################################################
    ### Apply the unambiguous information #################################################################
    for j in range(self.N):
        for (i, lag_i) in add_to_def_non_ancs[j]:
            old_link = self._get_link((i, lag_i), (j, 0))
            if len(old_link) > 0 and old_link[0] != "<":
                # Put an arrowhead at (i, lag_i): X is not an ancestor of Y
                new_link = "<" + old_link[1] + old_link[2]
                self._write_link((i, lag_i), (j, 0), new_link, verbosity = self.verbosity)
                put_head_or_tail = True
            if self.verbosity >= 1 and (i, lag_i) not in self.def_non_ancs[j]:
                print("{:10} Marking ({}, {:2}) as non-anc of {}".format("Update:", i, lag_i, (j, 0)))
            self.def_non_ancs[j].add((i, lag_i))
        for (i, lag_i) in add_to_def_ancs[j]:
            old_link = self._get_link((i, lag_i), (j, 0))
            if len(old_link) > 0 and (old_link[0] != "-" or old_link[2] != ">"):
                # Orient the link as X --> Y: X is an ancestor of Y
                new_link = "-" + old_link[1] + ">"
                self._write_link((i, lag_i), (j, 0), new_link, verbosity = self.verbosity)
                put_head_or_tail = True
            if self.verbosity >= 1 and (i, lag_i) not in self.def_ancs[j]:
                print("{:10} Marking ({}, {:2}) as anc of {}".format("Update:", i, lag_i, (j, 0)))
            self.def_ancs[j].add((i, lag_i))
            if lag_i == 0:
                # For contemporaneous X and Y, X being an ancestor of Y implies that Y
                # is a non-ancestor of X
                if self.verbosity >= 1 and (j, 0) not in self.def_non_ancs[i]:
                    print("{:10} Marking {} as non-anc of {}".format("Update:",(j, 0), (i, 0)))
                self.def_non_ancs[i].add((j, 0))
    #######################################################################################################
    return put_head_or_tail
def _apply_rule(self, rule, voting):
"""Call the orientation-removal-rule specified by the string argument rule. Pass on voting."""
if rule == "R-00-d":
return self._apply_R00(voting)
elif rule == "R-01":
return self._apply_R01(voting)
elif rule == "R-02":
return self._apply_R02()
elif rule == "R-03":
return self._apply_R03(voting)
elif rule == "R-04":
return self._apply_R04(voting)
elif rule == "R-08":
return self._apply_R08()
elif rule == "R-09":
return self._apply_R09(voting)
elif rule == "R-10":
return self._apply_R10(voting)