1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package edu.internet2.middleware.shibboleth.wayf;
17
18 import java.io.IOException;
19 import java.io.UnsupportedEncodingException;
20 import java.net.MalformedURLException;
21 import java.net.URL;
22 import java.net.URLDecoder;
23 import java.net.URLEncoder;
24 import java.util.ArrayList;
25 import java.util.Collection;
26 import java.util.Comparator;
27 import java.util.Date;
28 import java.util.HashSet;
29 import java.util.Hashtable;
30 import java.util.List;
31 import java.util.Locale;
32 import java.util.Map;
33 import java.util.Set;
34 import java.util.TreeSet;
35
36 import javax.servlet.RequestDispatcher;
37 import javax.servlet.ServletException;
38 import javax.servlet.http.HttpServletRequest;
39 import javax.servlet.http.HttpServletResponse;
40
41 import org.opensaml.saml2.common.Extensions;
42 import org.opensaml.saml2.metadata.EntityDescriptor;
43 import org.opensaml.saml2.metadata.RoleDescriptor;
44 import org.opensaml.saml2.metadata.SPSSODescriptor;
45 import org.opensaml.samlext.idpdisco.DiscoveryResponse;
46 import org.opensaml.xml.XMLObject;
47 import org.slf4j.Logger;
48 import org.slf4j.LoggerFactory;
49 import org.w3c.dom.Element;
50 import org.w3c.dom.NodeList;
51
52 import edu.internet2.middleware.shibboleth.common.ShibbolethConfigurationException;
53 import edu.internet2.middleware.shibboleth.wayf.plugins.Plugin;
54 import edu.internet2.middleware.shibboleth.wayf.plugins.PluginContext;
55 import edu.internet2.middleware.shibboleth.wayf.plugins.PluginMetadataParameter;
56 import edu.internet2.middleware.shibboleth.wayf.plugins.WayfRequestHandled;
57
58
59
60
61 public class DiscoveryServiceHandler {
62
63
64
65
66
67
68
69 private static final String SHIRE_PARAM_NAME = "shire";
70
71
72
73 private static final String TARGET_PARAM_NAME = "target";
74
75
76
77 private static final String TIME_PARAM_NAME = "time";
78
79
80
81 private static final String PROVIDERID_PARAM_NAME = "providerId";
82
83
84
85
86
87
88
89 private static final String ENTITYID_PARAM_NAME = "entityID";
90
91
92
93 private static final String RETURN_PARAM_NAME = "return";
94
95
96
97 private static final String RETURN_ATTRIBUTE_NAME = "returnX";
98
99
100
101 private static final String RETURN_INDEX_NAME = "returnIndex";
102
103
104
105
106 private static final String RETURNID_PARAM_NAME = "returnIDParam";
107
108
109
110
111 private static final String RETURNID_DEFAULT_VALUE = "entityID";
112
113
114
115 private static final String ISPASSIVE_PARAM_NAME = "isPassive";
116
117
118
119
120 private static final String POLICY_PARAM_NAME = "policy";
121
122
123
124
125 private static final String KNOWN_POLICY_NAME
126 = "urn:oasis:names:tc:SAML:profiles:SSO:idp-discoveryprotocol:single";
127
128
129
130
131 private static final Logger LOG = LoggerFactory.getLogger(DiscoveryServiceHandler.class.getName());
132
133
134
135
136 private final String location;
137
138
139
140
141 private final boolean isDefault;
142
143
144
145
146 private final HandlerConfig config;
147
148
149
150
151 private final List <IdPSiteSet> siteSets;
152
153
154
155
156 private final List <Plugin> plugins;
157
158
159
160
161
162
163
164
165
166 protected DiscoveryServiceHandler(Element config,
167 Hashtable <String, IdPSiteSet> federations,
168 Hashtable <String, Plugin> plugins,
169 HandlerConfig defaultConfig) throws ShibbolethConfigurationException
170 {
171 siteSets = new ArrayList <IdPSiteSet>(federations.size());
172 this.plugins = new ArrayList <Plugin>(plugins.size());
173
174
175
176
177
178 this.config = new HandlerConfig(config, defaultConfig);
179
180 location = config.getAttribute("location");
181
182 if (location == null || location.equals("")) {
183
184 LOG.error("DiscoveryService must have a location specified");
185 throw new ShibbolethConfigurationException("DiscoveryService must have a location specified");
186 }
187
188
189
190
191
192 String attribute = config.getAttribute("default");
193 if (attribute != null && !attribute.equals("")) {
194 isDefault = Boolean.valueOf(attribute).booleanValue();
195 } else {
196 isDefault = false;
197 }
198
199
200
201
202
203 NodeList list = config.getElementsByTagName("Federation");
204
205 for (int i = 0; i < list.getLength(); i++ ) {
206
207 attribute = ((Element) list.item(i)).getAttribute("identifier");
208
209 IdPSiteSet siteset = federations.get(attribute);
210
211 if (siteset == null) {
212 LOG.error("Handler " + location + ": could not find metadata for <Federation> with identifier " + attribute + ".");
213 throw new ShibbolethConfigurationException(
214 "Handler " + location + ": could not find metadata for <Federation> identifier " + attribute + ".");
215 }
216
217 siteSets.add(siteset);
218 }
219
220 if (siteSets.size() == 0) {
221
222
223
224 siteSets.addAll(federations.values());
225 }
226
227
228
229
230
231 list = config.getElementsByTagName("PluginInstance");
232
233 for (int i = 0; i < list.getLength(); i++ ) {
234
235 attribute = ((Element) list.item(i)).getAttribute("identifier");
236
237 Plugin plugin = plugins.get(attribute);
238
239 if (plugin == null) {
240 LOG.error("Handler " + location + ": could not find plugin for identifier " + attribute);
241 throw new ShibbolethConfigurationException(
242 "Handler " + location + ": could not find plugin for identifier " + attribute);
243 }
244
245 this.plugins.add(plugin);
246 }
247
248
249
250
251
252
253
254
255 for (IdPSiteSet site: siteSets) {
256 for (Plugin plugin: this.plugins) {
257 site.addPlugin(plugin);
258 }
259 }
260 }
261
262
263
264
265
266
267
268
269
270
271 protected String getLocation() {
272 return location;
273 }
274
275
276
277
278
279 protected boolean isDefault() {
280 return isDefault;
281 }
282
283
284
285
286
287 public void doGet(HttpServletRequest req, HttpServletResponse res) {
288
289 String policy = req.getParameter(POLICY_PARAM_NAME);
290
291 if (null != policy && !KNOWN_POLICY_NAME.equals(policy)) {
292
293
294
295 LOG.error("Unknown policy " + policy);
296 handleError(req, res, "Unknown policy " + policy);
297 return;
298 }
299
300
301
302
303 String requestType = req.getParameter("action");
304
305 if (requestType == null || requestType.equals("")) {
306 requestType = "lookup";
307 }
308
309 try {
310
311 if (requestType.equals("search")) {
312
313 String parameter = req.getParameter("string");
314 if (parameter != null && parameter.equals("")) {
315 parameter = null;
316 }
317 handleLookup(req, res, parameter);
318
319 } else if (requestType.equals("selection")) {
320
321 handleSelection(req, res);
322 } else {
323 handleLookup(req, res, null);
324 }
325 } catch (WayfException we) {
326 LOG.error("Error processing DS request:", we);
327 handleError(req, res, we.getLocalizedMessage());
328 } catch (WayfRequestHandled we) {
329
330
331
332 }
333
334 }
335
336
337
338
339
340
341
342
343
344 private void handleSelection(HttpServletRequest req,
345 HttpServletResponse res) throws WayfRequestHandled, WayfException
346 {
347
348 String idpName = req.getParameter("origin");
349 LOG.debug("Processing handle selection: " + idpName);
350
351 String sPName = getSPId(req);
352
353 if (idpName == null || idpName.equals("")) {
354 handleLookup(req, res, null);
355 return;
356 }
357
358 if (getValue(req, SHIRE_PARAM_NAME) == null) {
359
360
361
362 setupReturnAddress(sPName, req);
363 }
364
365
366
367 IdPSite site = null;
368
369 for (Plugin plugin:plugins) {
370 for (IdPSiteSet idPSiteSet: siteSets) {
371 PluginMetadataParameter param = idPSiteSet.paramFor(plugin);
372 plugin.selected(req, res, param, idpName);
373 if (site == null && idPSiteSet.containsIdP(idpName)) {
374 site = idPSiteSet.getSite(idpName);
375 }
376 }
377 }
378
379 if (site == null) {
380 handleLookup(req, res, null);
381 } else {
382 forwardRequest(req, res, site);
383 }
384 }
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400 private void setupReturnAddress(String spName, HttpServletRequest req) throws WayfException{
401
402 DiscoveryResponse[] discoveryServices;
403 Set<XMLObject> objects = new HashSet<XMLObject>();
404 String defaultName = null;
405
406 for (IdPSiteSet metadataProvider:siteSets) {
407
408
409
410
411
412 if (metadataProvider.containsSP(spName)) {
413
414
415
416
417
418
419 EntityDescriptor entity = metadataProvider.getEntity(spName);
420 List<RoleDescriptor> roles = entity.getRoleDescriptors();
421
422 for (RoleDescriptor role:roles) {
423
424
425
426
427
428 if (role instanceof SPSSODescriptor) {
429
430
431
432
433
434 Extensions exts = role.getExtensions();
435 if (exts != null) {
436 objects.addAll(exts.getOrderedChildren());
437 }
438 }
439 }
440 }
441 }
442
443
444
445
446
447 discoveryServices = new DiscoveryResponse[objects.size()];
448 int dsCount = 0;
449
450 for (XMLObject obj:objects) {
451 if (obj instanceof DiscoveryResponse) {
452 DiscoveryResponse ds = (DiscoveryResponse) obj;
453 discoveryServices[dsCount++] = ds;
454 if (ds.isDefault() || null == defaultName) {
455 defaultName = ds.getLocation();
456 }
457 }
458 }
459
460
461
462
463
464 String returnName = req.getParameter(RETURN_PARAM_NAME);
465
466 if (returnName == null || returnName.length() == 0) {
467 returnName = getValue(req, RETURN_ATTRIBUTE_NAME);
468 }
469
470
471
472
473
474 String returnIndex = req.getParameter(RETURN_INDEX_NAME);
475
476 if (returnName != null && returnName.length() != 0) {
477
478
479
480 String nameNoParam = returnName;
481 URL providedReturnURL;
482 int index = nameNoParam.indexOf('?');
483 boolean found = false;
484
485 if (index >= 0) {
486 nameNoParam = nameNoParam.substring(0,index);
487 }
488
489 try {
490 providedReturnURL = new URL(nameNoParam);
491 } catch (MalformedURLException e) {
492 throw new WayfException("Couldn't parse provided return name " + nameNoParam, e);
493 }
494
495
496 for (DiscoveryResponse disc: discoveryServices) {
497 if (equalsURL(disc, providedReturnURL)) {
498 found = true;
499 break;
500 }
501 }
502 if (!found) {
503 throw new WayfException("Couldn't find endpoint " + nameNoParam + " in metadata");
504 }
505 } else if (returnIndex != null && returnIndex.length() != 0) {
506
507 int index;
508 try {
509 index = Integer.parseInt(returnIndex);
510 } catch (NumberFormatException e) {
511 throw new WayfException("Couldn't convert " + returnIndex + " into an index");
512 }
513
514
515
516
517 boolean found = false;
518
519 for (DiscoveryResponse disc: discoveryServices) {
520 if (index == disc.getIndex()) {
521 found = true;
522 returnName = disc.getLocation();
523 break;
524 }
525 }
526 if (!found) {
527 throw new WayfException("Couldn't not find endpoint " + returnIndex + "in metadata");
528 }
529 } else {
530
531
532
533 returnName = defaultName;
534 }
535
536
537
538
539 req.setAttribute(RETURN_ATTRIBUTE_NAME, returnName);
540 }
541
542
543
544
545
546
547
548
549
550 private static boolean equalsURL(DiscoveryResponse discovery, URL providedName) {
551
552
553
554
555 if (null == discovery) {
556 return false;
557 }
558
559 URL discoveryName;
560 try {
561 discoveryName = new URL(discovery.getLocation());
562 } catch (MalformedURLException e) {
563
564
565
566 LOG.warn("Found invalid discovery end point : " + discovery.getLocation(), e);
567 return false;
568 }
569
570 return providedName.equals(discoveryName);
571
572 }
573
574
575
576
577
578
579
580
581
582
583 private void handleLookup(HttpServletRequest req,
584 HttpServletResponse res,
585 String searchName) throws WayfException, WayfRequestHandled {
586
587 String shire = getValue(req, SHIRE_PARAM_NAME);
588 String providerId = getSPId(req);
589 boolean twoZeroProtocol = (shire == null);
590 boolean isPassive = (twoZeroProtocol &&
591 "true".equalsIgnoreCase(getValue(req, ISPASSIVE_PARAM_NAME)));
592
593 Collection <IdPSiteSetEntry> siteLists = null;
594 Collection<IdPSite> searchResults = null;
595
596 if (config.getProvideListOfLists()) {
597 siteLists = new ArrayList <IdPSiteSetEntry>(siteSets.size());
598 }
599
600 Collection <IdPSite> sites = null;
601 Comparator<IdPSite> comparator = new IdPSite.Compare(req);
602
603 if (config.getProvideList()) {
604 sites = new TreeSet<IdPSite>(comparator);
605 }
606
607 if (searchName != null && !searchName.equals("")) {
608 searchResults = new TreeSet<IdPSite>(comparator);
609 }
610
611 LOG.debug("Processing Idp Lookup for : " + providerId);
612
613
614
615
616
617
618 PluginContext[] ctx = new PluginContext[plugins.size()];
619 List<IdPSite> hintList = new ArrayList<IdPSite>();
620
621 if (twoZeroProtocol) {
622 setupReturnAddress(providerId, req);
623 }
624
625
626
627
628 try {
629 for (IdPSiteSet metadataProvider:siteSets) {
630
631
632
633
634
635 if (metadataProvider.containsSP(providerId) || !config.getLookupSp()) {
636
637 Collection <IdPSite> search = null;
638
639 if (searchResults != null) {
640 search = new TreeSet<IdPSite>(comparator);
641 }
642
643 Map <String, IdPSite> theseSites = metadataProvider.getIdPSites(searchName, config, search);
644
645
646
647
648 for (int i = 0; i < plugins.size(); i++) {
649
650 Plugin plugin = plugins.get(i);
651
652 if (searchResults == null) {
653
654
655
656 ctx[i] = plugin.lookup(req,
657 res,
658 metadataProvider.paramFor(plugin),
659 theseSites,
660 ctx[i],
661 hintList);
662 } else {
663 ctx[i] = plugin.search(req,
664 res,
665 metadataProvider.paramFor(plugin),
666 searchName,
667 theseSites,
668 ctx[i],
669 searchResults,
670 hintList);
671 }
672 }
673
674 if (null == theseSites || theseSites.isEmpty()) {
675 continue;
676 }
677
678
679
680
681
682
683 Collection<IdPSite> values = new TreeSet<IdPSite>(comparator);
684 if (null != theseSites) {
685 values.addAll(theseSites.values());
686 }
687
688 if (siteLists != null) {
689 siteLists.add(new IdPSiteSetEntry(metadataProvider,values));
690 }
691
692 if (sites != null) {
693 sites.addAll(values);
694 }
695
696 if (searchResults != null) {
697 searchResults.addAll(search);
698 }
699 }
700 }
701
702 if (isPassive) {
703
704
705
706 if (0 != hintList.size()) {
707
708
709
710 forwardRequest(req, res, hintList.get(0));
711 } else {
712 forwardRequest(req, res, null);
713 }
714 return;
715 }
716
717
718
719
720
721
722 if (twoZeroProtocol) {
723
724
725
726 String returnString = (String) req.getAttribute(RETURN_ATTRIBUTE_NAME);
727 if (null == returnString || 0 == returnString.length()) {
728 throw new WayfException("Parameter " + RETURN_PARAM_NAME + " not supplied");
729 }
730
731 String returnId = getValue(req, RETURNID_PARAM_NAME);
732 if (null == returnId || 0 == returnId.length()) {
733 returnId = RETURNID_DEFAULT_VALUE;
734 }
735
736
737
738 req.setAttribute(RETURN_ATTRIBUTE_NAME, returnString);
739 req.setAttribute(RETURNID_PARAM_NAME, returnId);
740 req.setAttribute(ENTITYID_PARAM_NAME, providerId);
741
742 } else {
743 String target = getValue(req, TARGET_PARAM_NAME);
744 if (null == target || 0 == target.length()) {
745 throw new WayfException("Could not extract target from provided parameters");
746 }
747 req.setAttribute(SHIRE_PARAM_NAME, shire);
748 req.setAttribute(TARGET_PARAM_NAME, target);
749 req.setAttribute(PROVIDERID_PARAM_NAME, providerId);
750
751
752
753 req.setAttribute("time", new Long(new Date().getTime() / 1000).toString());
754
755 }
756
757
758
759
760
761 setDisplayLanguage(sites, req);
762 req.setAttribute("sites", sites);
763 if (null != siteLists) {
764 for (IdPSiteSetEntry siteSetEntry:siteLists) {
765 setDisplayLanguage(siteSetEntry.getSites(), req);
766 }
767 }
768
769 req.setAttribute("siteLists", siteLists);
770 req.setAttribute("requestURL", req.getRequestURI().toString());
771
772 if (searchResults != null) {
773 if (searchResults.size() != 0) {
774 setDisplayLanguage(searchResults, req);
775 req.setAttribute("searchresults", searchResults);
776 } else {
777 req.setAttribute("searchResultsEmpty", "true");
778 }
779 }
780
781 if (hintList.size() > 0) {
782 setDisplayLanguage(hintList, req);
783 req.setAttribute("cookieList", hintList);
784 }
785
786 LOG.debug("Displaying WAYF selection page.");
787 RequestDispatcher rd = req.getRequestDispatcher(config.getJspFile());
788
789
790
791
792 rd.forward(req, res);
793 } catch (IOException ioe) {
794 LOG.error("Problem displaying WAYF UI.\n" + ioe.getMessage());
795 throw new WayfException("Problem displaying WAYF UI", ioe);
796 } catch (ServletException se) {
797 LOG.error("Problem displaying WAYF UI.\n" + se.getMessage());
798 throw new WayfException("Problem displaying WAYF UI", se);
799 }
800 }
801
802
803
804
805
806
807
808
809 private void setDisplayLanguage(Collection<IdPSite> sites, HttpServletRequest req) {
810
811 if (null == sites) {
812 return;
813 }
814 Locale locale = req.getLocale();
815 if (null == locale) {
816 Locale.getDefault();
817 }
818 String lang = locale.getLanguage();
819
820 for (IdPSite site : sites) {
821 site.setDisplayLanguage(lang);
822 }
823 }
824
825
826
827
828
829
830
831
832
833
834 public static void forwardRequest(HttpServletRequest req, HttpServletResponse res, IdPSite site)
835 throws WayfException {
836
837 String shire = getValue(req, SHIRE_PARAM_NAME);
838 String providerId = getSPId(req);
839 boolean twoZeroProtocol = (shire == null);
840
841 if (!twoZeroProtocol) {
842 String handleService = site.getAddressForWAYF();
843 if (handleService != null ) {
844
845 String target = getValue(req, TARGET_PARAM_NAME);
846 if (null == target || 0 == target.length()) {
847 throw new WayfException("Could not extract target from provided parameters");
848 }
849
850 LOG.info("Redirecting to selected Handle Service: " + handleService);
851 try {
852 StringBuffer buffer = new StringBuffer(handleService +
853 "?" + TARGET_PARAM_NAME + "=");
854 buffer.append(URLEncoder.encode(target, "UTF-8"));
855 buffer.append("&" + SHIRE_PARAM_NAME + "=");
856 buffer.append(URLEncoder.encode(shire, "UTF-8"));
857 buffer.append("&" + PROVIDERID_PARAM_NAME + "=");
858 buffer.append(URLEncoder.encode(providerId, "UTF-8"));
859
860
861
862
863 buffer.append("&" + TIME_PARAM_NAME + "=");
864 buffer.append(new Long(new Date().getTime() / 1000).toString());
865 res.sendRedirect(buffer.toString());
866 } catch (IOException ioe) {
867
868
869
870 throw new WayfException("Error forwarding to IdP: \n" + ioe.getMessage());
871 }
872 } else {
873 String s = "Error finding to IdP: " + site.getDisplayName(req);
874 LOG.error(s);
875 throw new WayfException(s);
876 }
877 } else {
878 String returnUrl = (String) req.getAttribute(RETURN_ATTRIBUTE_NAME);
879
880 if (null == returnUrl || 0 == returnUrl.length()) {
881 throw new WayfException("Could not find return parameter");
882 }
883 try {
884 returnUrl = URLDecoder.decode(returnUrl, "UTF-8");
885 } catch (UnsupportedEncodingException e) {
886 throw new WayfException("Did not understand parameter ", e);
887 }
888 String redirect;
889 if (site != null) {
890 StringBuffer buffer = new StringBuffer(returnUrl);
891
892
893
894 String returnParam = getValue(req, RETURNID_PARAM_NAME);
895 if (null == returnParam || 0 == returnParam.length()) {
896 returnParam = RETURNID_DEFAULT_VALUE;
897 }
898
899
900
901
902 if (returnUrl.indexOf('?') >= 0) {
903
904
905
906 buffer.append("&" + returnParam + "=");
907 } else {
908
909
910
911 buffer.append("?" + returnParam + "=");
912 }
913 buffer.append(site.getName());
914 redirect = buffer.toString();
915 } else {
916
917
918
919 redirect = returnUrl;
920 }
921
922 LOG.debug("Dispatching to " + redirect);
923
924 try {
925 res.sendRedirect(redirect);
926 } catch (IOException ioe) {
927
928
929
930 throw new WayfException("Error forwarding back to Sp: \n" + ioe.getMessage());
931 }
932 }
933 }
934
935
936
937
938
939
940
941
942
943 private void handleError(HttpServletRequest req, HttpServletResponse res, String message) {
944
945 LOG.debug("Displaying WAYF error page.");
946 req.setAttribute("errorText", message);
947 req.setAttribute("requestURL", req.getRequestURI().toString());
948 RequestDispatcher rd = req.getRequestDispatcher(config.getErrorJspFile());
949
950 try {
951 rd.forward(req, res);
952 } catch (IOException ioe) {
953 LOG.error("Problem trying to display WAYF error page: " + ioe.toString());
954 } catch (ServletException se) {
955 LOG.error("Problem trying to display WAYF error page: " + se.toString());
956 }
957 }
958
959
960
961
962
963
964
965 private static String getValue(HttpServletRequest req, String name) {
966
967
968 String value = req.getParameter(name);
969 if (value != null) {
970 return value;
971 }
972 return (String) req.getAttribute(name);
973 }
974
975 private static String getSPId(HttpServletRequest req) throws WayfException {
976
977
978
979
980 String param = req.getParameter(ENTITYID_PARAM_NAME);
981 if (param != null && !(param.length() == 0)) {
982 return param;
983 }
984
985 param = (String) req.getAttribute(ENTITYID_PARAM_NAME);
986 if (param != null && !(param.length() == 0)) {
987 return param;
988 }
989
990
991
992 param = req.getParameter(PROVIDERID_PARAM_NAME);
993 if (param != null && !(param.length() == 0)) {
994 return param;
995 }
996
997 param = (String) req.getAttribute(PROVIDERID_PARAM_NAME);
998 if (param != null && !(param.length() == 0)) {
999 return param;
1000 }
1001 throw new WayfException("Could not locate SP identifier in parameters");
1002 }
1003 }