1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package edu.internet2.middleware.shibboleth.wayf;
17
18 import java.io.IOException;
19 import java.io.UnsupportedEncodingException;
20 import java.net.MalformedURLException;
21 import java.net.URL;
22 import java.net.URLDecoder;
23 import java.net.URLEncoder;
24 import java.util.ArrayList;
25 import java.util.Collection;
26 import java.util.Comparator;
27 import java.util.Date;
28 import java.util.HashSet;
29 import java.util.Hashtable;
30 import java.util.List;
31 import java.util.Locale;
32 import java.util.Map;
33 import java.util.Set;
34 import java.util.TreeSet;
35
36 import javax.servlet.RequestDispatcher;
37 import javax.servlet.ServletException;
38 import javax.servlet.http.HttpServletRequest;
39 import javax.servlet.http.HttpServletResponse;
40
41 import org.opensaml.saml2.common.Extensions;
42 import org.opensaml.saml2.metadata.EntityDescriptor;
43 import org.opensaml.saml2.metadata.RoleDescriptor;
44 import org.opensaml.saml2.metadata.SPSSODescriptor;
45 import org.opensaml.samlext.idpdisco.DiscoveryResponse;
46 import org.opensaml.xml.XMLObject;
47 import org.slf4j.Logger;
48 import org.slf4j.LoggerFactory;
49 import org.w3c.dom.Element;
50 import org.w3c.dom.NodeList;
51
52 import edu.internet2.middleware.shibboleth.common.ShibbolethConfigurationException;
53 import edu.internet2.middleware.shibboleth.wayf.plugins.Plugin;
54 import edu.internet2.middleware.shibboleth.wayf.plugins.PluginContext;
55 import edu.internet2.middleware.shibboleth.wayf.plugins.PluginMetadataParameter;
56 import edu.internet2.middleware.shibboleth.wayf.plugins.WayfRequestHandled;
57
58
59
60
61 public class DiscoveryServiceHandler {
62
63
64
65
66
67
68
69 private static final String SHIRE_PARAM_NAME = "shire";
70
71
72
73 private static final String TARGET_PARAM_NAME = "target";
74
75
76
77 private static final String TIME_PARAM_NAME = "time";
78
79
80
81 private static final String PROVIDERID_PARAM_NAME = "providerId";
82
83
84
85
86 private static final String PROVIDERID_OBJECT_PARAM_NAME = "providerObject";
87
88
89
90
91
92
93 private static final String ENTITYID_PARAM_NAME = "entityID";
94
95
96
97 private static final String RETURN_PARAM_NAME = "return";
98
99
100
101 private static final String RETURN_ATTRIBUTE_NAME = "returnX";
102
103
104
105 private static final String RETURN_INDEX_NAME = "returnIndex";
106
107
108
109
110 private static final String RETURNID_PARAM_NAME = "returnIDParam";
111
112
113
114
115 private static final String RETURNID_DEFAULT_VALUE = "entityID";
116
117
118
119 private static final String ISPASSIVE_PARAM_NAME = "isPassive";
120
121
122
123
124 private static final String POLICY_PARAM_NAME = "policy";
125
126
127
128
129 private static final String KNOWN_POLICY_NAME
130 = "urn:oasis:names:tc:SAML:profiles:SSO:idp-discoveryprotocol:single";
131
132
133
134
135 private static final Logger LOG = LoggerFactory.getLogger(DiscoveryServiceHandler.class.getName());
136
137
138
139
140 private final String location;
141
142
143
144
145 private final boolean isDefault;
146
147
148
149
150 private final HandlerConfig config;
151
152
153
154
155 private final List <IdPSiteSet> siteSets;
156
157
158
159
160 private final List <Plugin> plugins;
161
162
163
164
165
166
167
168
169
170 protected DiscoveryServiceHandler(Element config,
171 Hashtable <String, IdPSiteSet> federations,
172 Hashtable <String, Plugin> plugins,
173 HandlerConfig defaultConfig) throws ShibbolethConfigurationException
174 {
175 siteSets = new ArrayList <IdPSiteSet>(federations.size());
176 this.plugins = new ArrayList <Plugin>(plugins.size());
177
178
179
180
181
182 this.config = new HandlerConfig(config, defaultConfig);
183
184 location = config.getAttribute("location");
185
186 if (location == null || location.equals("")) {
187
188 LOG.error("DiscoveryService must have a location specified");
189 throw new ShibbolethConfigurationException("DiscoveryService must have a location specified");
190 }
191
192
193
194
195
196 String attribute = config.getAttribute("default");
197 if (attribute != null && !attribute.equals("")) {
198 isDefault = Boolean.valueOf(attribute).booleanValue();
199 } else {
200 isDefault = false;
201 }
202
203
204
205
206
207 NodeList list = config.getElementsByTagName("Federation");
208
209 for (int i = 0; i < list.getLength(); i++ ) {
210
211 attribute = ((Element) list.item(i)).getAttribute("identifier");
212
213 IdPSiteSet siteset = federations.get(attribute);
214
215 if (siteset == null) {
216 LOG.error("Handler " + location + ": could not find metadata for <Federation> with identifier " + attribute + ".");
217 throw new ShibbolethConfigurationException(
218 "Handler " + location + ": could not find metadata for <Federation> identifier " + attribute + ".");
219 }
220
221 siteSets.add(siteset);
222 }
223
224 if (siteSets.size() == 0) {
225
226
227
228 siteSets.addAll(federations.values());
229 }
230
231
232
233
234
235 list = config.getElementsByTagName("PluginInstance");
236
237 for (int i = 0; i < list.getLength(); i++ ) {
238
239 attribute = ((Element) list.item(i)).getAttribute("identifier");
240
241 Plugin plugin = plugins.get(attribute);
242
243 if (plugin == null) {
244 LOG.error("Handler " + location + ": could not find plugin for identifier " + attribute);
245 throw new ShibbolethConfigurationException(
246 "Handler " + location + ": could not find plugin for identifier " + attribute);
247 }
248
249 this.plugins.add(plugin);
250 }
251
252
253
254
255
256
257
258
259 for (IdPSiteSet site: siteSets) {
260 for (Plugin plugin: this.plugins) {
261 site.addPlugin(plugin);
262 }
263 }
264 }
265
266
267
268
269
270
271
272
273
274
275 protected String getLocation() {
276 return location;
277 }
278
279
280
281
282
283 protected boolean isDefault() {
284 return isDefault;
285 }
286
287
288
289
290
291 public void doGet(HttpServletRequest req, HttpServletResponse res) {
292
293 String policy = req.getParameter(POLICY_PARAM_NAME);
294
295 if (null != policy && !KNOWN_POLICY_NAME.equals(policy)) {
296
297
298
299 LOG.error("Unknown policy " + policy);
300 handleError(req, res, "Unknown policy " + policy);
301 return;
302 }
303
304
305
306
307 String requestType = req.getParameter("action");
308
309 if (requestType == null || requestType.equals("")) {
310 requestType = "lookup";
311 }
312
313 try {
314
315 if (requestType.equals("search")) {
316
317 String parameter = req.getParameter("string");
318 if (parameter != null && parameter.equals("")) {
319 parameter = null;
320 }
321 handleLookup(req, res, parameter);
322
323 } else if (requestType.equals("selection")) {
324
325 handleSelection(req, res);
326 } else {
327 handleLookup(req, res, null);
328 }
329 } catch (WayfException we) {
330 LOG.error("Error processing DS request:", we);
331 handleError(req, res, we.getLocalizedMessage());
332 } catch (WayfRequestHandled we) {
333
334
335
336 }
337
338 }
339
340
341
342
343
344
345
346
347
348 private void handleSelection(HttpServletRequest req,
349 HttpServletResponse res) throws WayfRequestHandled, WayfException
350 {
351
352 String idpName = req.getParameter("origin");
353 LOG.debug("Processing handle selection: " + idpName);
354
355 String sPName = getSPId(req);
356
357 if (idpName == null || idpName.equals("")) {
358 handleLookup(req, res, null);
359 return;
360 }
361
362 if (getValue(req, SHIRE_PARAM_NAME) == null) {
363
364
365
366 setupReturnAddress(sPName, req);
367 }
368
369
370
371 IdPSite site = null;
372
373 for (Plugin plugin:plugins) {
374 for (IdPSiteSet idPSiteSet: siteSets) {
375 PluginMetadataParameter param = idPSiteSet.paramFor(plugin);
376 plugin.selected(req, res, param, idpName);
377 if (site == null && idPSiteSet.containsIdP(idpName)) {
378 site = idPSiteSet.getSite(idpName);
379 }
380 }
381 }
382
383 if (site == null) {
384 handleLookup(req, res, null);
385 } else {
386 forwardRequest(req, res, site);
387 }
388 }
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404 private void setupReturnAddress(String spName, HttpServletRequest req) throws WayfException{
405
406 DiscoveryResponse[] discoveryServices;
407 Set<XMLObject> objects = new HashSet<XMLObject>();
408 String defaultName = null;
409 boolean foundSPName = false;
410
411 for (IdPSiteSet metadataProvider:siteSets) {
412
413
414
415
416
417 if (metadataProvider.containsSP(spName)) {
418
419
420
421
422
423 foundSPName = true;
424 EntityDescriptor entity = metadataProvider.getEntity(spName);
425 List<RoleDescriptor> roles = entity.getRoleDescriptors();
426
427 for (RoleDescriptor role:roles) {
428
429
430
431
432
433 if (role instanceof SPSSODescriptor) {
434
435
436
437
438
439 Extensions exts = role.getExtensions();
440 if (exts != null) {
441 objects.addAll(exts.getOrderedChildren());
442 }
443 }
444 }
445 }
446 }
447 if (!foundSPName) {
448 LOG.error("Could not locate SP " + spName + " in metadata");
449 }
450
451
452
453
454
455 discoveryServices = new DiscoveryResponse[objects.size()];
456 int dsCount = 0;
457
458 for (XMLObject obj:objects) {
459 if (obj instanceof DiscoveryResponse) {
460 DiscoveryResponse ds = (DiscoveryResponse) obj;
461 discoveryServices[dsCount++] = ds;
462 if (ds.isDefault() || null == defaultName) {
463 defaultName = ds.getLocation();
464 }
465 }
466 }
467
468
469
470
471
472 String returnName = req.getParameter(RETURN_PARAM_NAME);
473
474 if (returnName == null || returnName.length() == 0) {
475 returnName = getValue(req, RETURN_ATTRIBUTE_NAME);
476 }
477
478
479
480
481
482 String returnIndex = req.getParameter(RETURN_INDEX_NAME);
483
484 if (returnName != null && returnName.length() != 0) {
485
486
487
488 String nameNoParam = returnName;
489 URL providedReturnURL;
490 int index = nameNoParam.indexOf('?');
491 boolean found = false;
492
493 if (index >= 0) {
494 nameNoParam = nameNoParam.substring(0,index);
495 }
496
497 try {
498 providedReturnURL = new URL(nameNoParam);
499 } catch (MalformedURLException e) {
500 throw new WayfException("Couldn't parse provided return name " + nameNoParam, e);
501 }
502
503
504 for (DiscoveryResponse disc: discoveryServices) {
505 if (equalsURL(disc, providedReturnURL)) {
506 found = true;
507 break;
508 }
509 }
510 if (!found) {
511 throw new WayfException("Couldn't find endpoint " + nameNoParam + " in metadata");
512 }
513 } else if (returnIndex != null && returnIndex.length() != 0) {
514
515 int index;
516 try {
517 index = Integer.parseInt(returnIndex);
518 } catch (NumberFormatException e) {
519 throw new WayfException("Couldn't convert " + returnIndex + " into an index");
520 }
521
522
523
524 boolean found = false;
525
526 for (DiscoveryResponse disc: discoveryServices) {
527 if (index == disc.getIndex()) {
528 found = true;
529 returnName = disc.getLocation();
530 break;
531 }
532 }
533 if (!found) {
534 throw new WayfException("Couldn't not find endpoint " + returnIndex + "in metadata");
535 }
536 } else {
537
538
539
540 returnName = defaultName;
541 }
542
543
544
545
546 req.setAttribute(RETURN_ATTRIBUTE_NAME, returnName);
547 }
548
549
550
551
552
553
554
555
556
557 private static boolean equalsURL(DiscoveryResponse discovery, URL providedName) {
558
559
560
561
562 if (null == discovery) {
563 return false;
564 }
565
566 URL discoveryName;
567 try {
568 discoveryName = new URL(discovery.getLocation());
569 } catch (MalformedURLException e) {
570
571
572
573 LOG.warn("Found invalid discovery end point : " + discovery.getLocation(), e);
574 return false;
575 }
576
577 return providedName.equals(discoveryName);
578
579 }
580
581
582
583
584
585
586
587
588
589
590 private void handleLookup(HttpServletRequest req,
591 HttpServletResponse res,
592 String searchName) throws WayfException, WayfRequestHandled {
593
594 String shire = getValue(req, SHIRE_PARAM_NAME);
595 String providerId = getSPId(req);
596 EntityDescriptor sp = null;
597 boolean twoZeroProtocol = (shire == null);
598 boolean isPassive = (twoZeroProtocol &&
599 "true".equalsIgnoreCase(getValue(req, ISPASSIVE_PARAM_NAME)));
600
601 Collection <IdPSiteSetEntry> siteLists = null;
602 Collection<IdPSite> searchResults = null;
603
604 if (config.getProvideListOfLists()) {
605 siteLists = new ArrayList <IdPSiteSetEntry>(siteSets.size());
606 }
607
608 Collection <IdPSite> sites = null;
609 Comparator<IdPSite> comparator = new IdPSite.Compare(req);
610
611 if (config.getProvideList()) {
612 sites = new TreeSet<IdPSite>(comparator);
613 }
614
615 if (searchName != null && !searchName.equals("")) {
616 searchResults = new TreeSet<IdPSite>(comparator);
617 }
618
619 LOG.debug("Processing Idp Lookup for : " + providerId);
620
621
622
623
624
625
626 PluginContext[] ctx = new PluginContext[plugins.size()];
627 List<IdPSite> hintList = new ArrayList<IdPSite>();
628
629 if (twoZeroProtocol) {
630 setupReturnAddress(providerId, req);
631 }
632
633
634
635
636 try {
637 for (IdPSiteSet metadataProvider:siteSets) {
638
639
640
641
642
643 if (metadataProvider.containsSP(providerId) || !config.getLookupSp()) {
644
645 Collection <IdPSite> search = null;
646
647 if (null == sp) {
648 sp = metadataProvider.getEntity(providerId);
649 }
650
651 if (searchResults != null) {
652 search = new TreeSet<IdPSite>(comparator);
653 }
654
655 Map <String, IdPSite> theseSites = metadataProvider.getIdPSites(searchName, config, search);
656
657
658
659
660 for (int i = 0; i < plugins.size(); i++) {
661
662 Plugin plugin = plugins.get(i);
663
664 if (searchResults == null) {
665
666
667
668 ctx[i] = plugin.lookup(req,
669 res,
670 metadataProvider.paramFor(plugin),
671 theseSites,
672 ctx[i],
673 hintList);
674 } else {
675 ctx[i] = plugin.search(req,
676 res,
677 metadataProvider.paramFor(plugin),
678 searchName,
679 theseSites,
680 ctx[i],
681 searchResults,
682 hintList);
683 }
684 }
685
686 if (null == theseSites || theseSites.isEmpty()) {
687 continue;
688 }
689
690
691
692
693
694
695 Collection<IdPSite> values = new TreeSet<IdPSite>(comparator);
696 if (null != theseSites) {
697 values.addAll(theseSites.values());
698 }
699
700 if (siteLists != null) {
701 siteLists.add(new IdPSiteSetEntry(metadataProvider,values));
702 }
703
704 if (sites != null) {
705 sites.addAll(values);
706 }
707
708 if (searchResults != null) {
709 searchResults.addAll(search);
710 }
711 }
712 }
713
714 if (isPassive) {
715
716
717
718 if (0 != hintList.size()) {
719
720
721
722 forwardRequest(req, res, hintList.get(0));
723 } else {
724 forwardRequest(req, res, null);
725 }
726 return;
727 }
728
729
730
731
732
733
734 if (twoZeroProtocol) {
735
736
737
738 String returnString = (String) req.getAttribute(RETURN_ATTRIBUTE_NAME);
739 if (null == returnString || 0 == returnString.length()) {
740 throw new WayfException("Parameter " + RETURN_PARAM_NAME + " not supplied");
741 }
742
743 String returnId = getValue(req, RETURNID_PARAM_NAME);
744 if (null == returnId || 0 == returnId.length()) {
745 returnId = RETURNID_DEFAULT_VALUE;
746 }
747
748
749
750 req.setAttribute(RETURN_ATTRIBUTE_NAME, returnString);
751 req.setAttribute(RETURNID_PARAM_NAME, returnId);
752 req.setAttribute(ENTITYID_PARAM_NAME, providerId);
753
754 } else {
755 String target = getValue(req, TARGET_PARAM_NAME);
756 if (null == target || 0 == target.length()) {
757 throw new WayfException("Could not extract target from provided parameters");
758 }
759 req.setAttribute(SHIRE_PARAM_NAME, shire);
760 req.setAttribute(TARGET_PARAM_NAME, target);
761 req.setAttribute(PROVIDERID_PARAM_NAME, providerId);
762
763
764
765 req.setAttribute("time", new Long(new Date().getTime() / 1000).toString());
766
767 }
768
769
770
771
772 setDisplayLanguage(sites, req);
773 req.setAttribute("sites", sites);
774 if (null != siteLists) {
775 for (IdPSiteSetEntry siteSetEntry:siteLists) {
776 setDisplayLanguage(siteSetEntry.getSites(), req);
777 }
778 }
779
780 req.setAttribute(PROVIDERID_OBJECT_PARAM_NAME, sp);
781
782 req.setAttribute("siteLists", siteLists);
783 req.setAttribute("requestURL", req.getRequestURI().toString());
784
785 if (searchResults != null) {
786 if (searchResults.size() != 0) {
787 setDisplayLanguage(searchResults, req);
788 req.setAttribute("searchresults", searchResults);
789 } else {
790 req.setAttribute("searchResultsEmpty", "true");
791 }
792 }
793
794 if (hintList.size() > 0) {
795 setDisplayLanguage(hintList, req);
796 req.setAttribute("cookieList", hintList);
797 }
798
799 LOG.debug("Displaying WAYF selection page.");
800 RequestDispatcher rd = req.getRequestDispatcher(config.getJspFile());
801
802
803
804
805 rd.forward(req, res);
806 } catch (IOException ioe) {
807 LOG.error("Problem displaying WAYF UI.\n" + ioe.getMessage());
808 throw new WayfException("Problem displaying WAYF UI", ioe);
809 } catch (ServletException se) {
810 LOG.error("Problem displaying WAYF UI.\n" + se.getMessage());
811 throw new WayfException("Problem displaying WAYF UI", se);
812 }
813 }
814
815
816
817
818
819
820
821
822 private void setDisplayLanguage(Collection<IdPSite> sites, HttpServletRequest req) {
823
824 if (null == sites) {
825 return;
826 }
827 Locale locale = req.getLocale();
828 if (null == locale) {
829 Locale.getDefault();
830 }
831 String lang = locale.getLanguage();
832
833 for (IdPSite site : sites) {
834 site.setDisplayLanguage(lang);
835 }
836 }
837
838
839
840
841
842
843
844
845
846
847 public static void forwardRequest(HttpServletRequest req, HttpServletResponse res, IdPSite site)
848 throws WayfException {
849
850 String shire = getValue(req, SHIRE_PARAM_NAME);
851 String providerId = getSPId(req);
852 boolean twoZeroProtocol = (shire == null);
853
854 if (!twoZeroProtocol) {
855 String handleService = site.getAddressForWAYF();
856 if (handleService != null ) {
857
858 String target = getValue(req, TARGET_PARAM_NAME);
859 if (null == target || 0 == target.length()) {
860 throw new WayfException("Could not extract target from provided parameters");
861 }
862
863 LOG.info("Redirecting to selected Handle Service: " + handleService);
864 try {
865 StringBuffer buffer = new StringBuffer(handleService +
866 "?" + TARGET_PARAM_NAME + "=");
867 buffer.append(URLEncoder.encode(target, "UTF-8"));
868 buffer.append("&" + SHIRE_PARAM_NAME + "=");
869 buffer.append(URLEncoder.encode(shire, "UTF-8"));
870 buffer.append("&" + PROVIDERID_PARAM_NAME + "=");
871 buffer.append(URLEncoder.encode(providerId, "UTF-8"));
872
873
874
875
876 buffer.append("&" + TIME_PARAM_NAME + "=");
877 buffer.append(new Long(new Date().getTime() / 1000).toString());
878 res.sendRedirect(buffer.toString());
879 } catch (IOException ioe) {
880
881
882
883 throw new WayfException("Error forwarding to IdP: \n" + ioe.getMessage());
884 }
885 } else {
886 String s = "Error finding to IdP: " + site.getDisplayName(req);
887 LOG.error(s);
888 throw new WayfException(s);
889 }
890 } else {
891 String returnUrl = (String) req.getAttribute(RETURN_ATTRIBUTE_NAME);
892
893 if (null == returnUrl || 0 == returnUrl.length()) {
894 throw new WayfException("Could not find return parameter");
895 }
896 try {
897 returnUrl = URLDecoder.decode(returnUrl, "UTF-8");
898 } catch (UnsupportedEncodingException e) {
899 throw new WayfException("Did not understand parameter ", e);
900 }
901 String redirect;
902 if (site != null) {
903 StringBuffer buffer = new StringBuffer(returnUrl);
904
905
906
907 String returnParam = getValue(req, RETURNID_PARAM_NAME);
908 if (null == returnParam || 0 == returnParam.length()) {
909 returnParam = RETURNID_DEFAULT_VALUE;
910 }
911
912
913
914
915 if (returnUrl.indexOf('?') >= 0) {
916
917
918
919 buffer.append("&" + returnParam + "=");
920 } else {
921
922
923
924 buffer.append("?" + returnParam + "=");
925 }
926 buffer.append(site.getName());
927 redirect = buffer.toString();
928 } else {
929
930
931
932 redirect = returnUrl;
933 }
934
935 LOG.debug("Dispatching to " + redirect);
936
937 try {
938 res.sendRedirect(redirect);
939 } catch (IOException ioe) {
940
941
942
943 throw new WayfException("Error forwarding back to Sp: \n" + ioe.getMessage());
944 }
945 }
946 }
947
948
949
950
951
952
953
954
955
956 private void handleError(HttpServletRequest req, HttpServletResponse res, String message) {
957
958 LOG.debug("Displaying WAYF error page.");
959 req.setAttribute("errorText", message);
960 req.setAttribute("requestURL", req.getRequestURI().toString());
961 RequestDispatcher rd = req.getRequestDispatcher(config.getErrorJspFile());
962
963 try {
964 rd.forward(req, res);
965 } catch (IOException ioe) {
966 LOG.error("Problem trying to display WAYF error page: " + ioe.toString());
967 } catch (ServletException se) {
968 LOG.error("Problem trying to display WAYF error page: " + se.toString());
969 }
970 }
971
972
973
974
975
976
977
978 private static String getValue(HttpServletRequest req, String name) {
979
980
981 String value = req.getParameter(name);
982 if (value != null) {
983 return value;
984 }
985 return (String) req.getAttribute(name);
986 }
987
988 private static String getSPId(HttpServletRequest req) throws WayfException {
989
990
991
992
993 String param = req.getParameter(ENTITYID_PARAM_NAME);
994 if (param != null && !(param.length() == 0)) {
995 return param;
996 }
997
998 param = (String) req.getAttribute(ENTITYID_PARAM_NAME);
999 if (param != null && !(param.length() == 0)) {
1000 return param;
1001 }
1002
1003
1004
1005 param = req.getParameter(PROVIDERID_PARAM_NAME);
1006 if (param != null && !(param.length() == 0)) {
1007 return param;
1008 }
1009
1010 param = (String) req.getAttribute(PROVIDERID_PARAM_NAME);
1011 if (param != null && !(param.length() == 0)) {
1012 return param;
1013 }
1014 throw new WayfException("Could not locate SP identifier in parameters");
1015 }
1016 }