1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package edu.internet2.middleware.shibboleth.wayf;
17
18 import java.io.IOException;
19 import java.io.UnsupportedEncodingException;
20 import java.net.MalformedURLException;
21 import java.net.URL;
22 import java.net.URLDecoder;
23 import java.net.URLEncoder;
24 import java.util.ArrayList;
25 import java.util.Collection;
26 import java.util.Comparator;
27 import java.util.Date;
28 import java.util.HashSet;
29 import java.util.Hashtable;
30 import java.util.List;
31 import java.util.Locale;
32 import java.util.Map;
33 import java.util.Set;
34 import java.util.TreeSet;
35
36 import javax.servlet.RequestDispatcher;
37 import javax.servlet.ServletException;
38 import javax.servlet.http.HttpServletRequest;
39 import javax.servlet.http.HttpServletResponse;
40
41 import org.opensaml.saml2.common.Extensions;
42 import org.opensaml.saml2.metadata.EntityDescriptor;
43 import org.opensaml.saml2.metadata.RoleDescriptor;
44 import org.opensaml.saml2.metadata.SPSSODescriptor;
45 import org.opensaml.xml.XMLObject;
46 import org.slf4j.Logger;
47 import org.slf4j.LoggerFactory;
48 import org.w3c.dom.Element;
49 import org.w3c.dom.NodeList;
50
51 import edu.internet2.middleware.shibboleth.common.ShibbolethConfigurationException;
52 import edu.internet2.middleware.shibboleth.wayf.plugins.Plugin;
53 import edu.internet2.middleware.shibboleth.wayf.plugins.PluginContext;
54 import edu.internet2.middleware.shibboleth.wayf.plugins.PluginMetadataParameter;
55 import edu.internet2.middleware.shibboleth.wayf.plugins.WayfRequestHandled;
56
57
58
59
60 public class DiscoveryServiceHandler {
61
62
63
64
65
66
67
68 private static final String SHIRE_PARAM_NAME = "shire";
69
70
71
72 private static final String TARGET_PARAM_NAME = "target";
73
74
75
76 private static final String TIME_PARAM_NAME = "time";
77
78
79
80 private static final String PROVIDERID_PARAM_NAME = "providerId";
81
82
83
84
85
86
87
88 private static final String ENTITYID_PARAM_NAME = "entityID";
89
90
91
92 private static final String RETURN_PARAM_NAME = "return";
93
94
95
96 private static final String RETURN_ATTRIBUTE_NAME = "returnX";
97
98
99
100 private static final String RETURN_INDEX_NAME = "returnIndex";
101
102
103
104
105 private static final String RETURNID_PARAM_NAME = "returnIDParam";
106
107
108
109
110 private static final String RETURNID_DEFAULT_VALUE = "entityID";
111
112
113
114 private static final String ISPASSIVE_PARAM_NAME = "isPassive";
115
116
117
118
119 private static final String POLICY_PARAM_NAME = "policy";
120
121
122
123
124 private static final String KNOWN_POLICY_NAME
125 = "urn:oasis:names:tc:SAML:profiles:SSO:idp-discoveryprotocol:single";
126
127
128
129
130 private static final Logger LOG = LoggerFactory.getLogger(DiscoveryServiceHandler.class.getName());
131
132
133
134
135 private final String location;
136
137
138
139
140 private final boolean isDefault;
141
142
143
144
145 private final HandlerConfig config;
146
147
148
149
150 private final List <IdPSiteSet> siteSets;
151
152
153
154
155 private final List <Plugin> plugins;
156
157
158
159
160
161
162
163
164
165 protected DiscoveryServiceHandler(Element config,
166 Hashtable <String, IdPSiteSet> federations,
167 Hashtable <String, Plugin> plugins,
168 HandlerConfig defaultConfig) throws ShibbolethConfigurationException
169 {
170 siteSets = new ArrayList <IdPSiteSet>(federations.size());
171 this.plugins = new ArrayList <Plugin>(plugins.size());
172
173
174
175
176
177 this.config = new HandlerConfig(config, defaultConfig);
178
179 location = config.getAttribute("location");
180
181 if (location == null || location.equals("")) {
182
183 LOG.error("DiscoveryService must have a location specified");
184 throw new ShibbolethConfigurationException("DiscoveryService must have a location specified");
185 }
186
187
188
189
190
191 String attribute = config.getAttribute("default");
192 if (attribute != null && !attribute.equals("")) {
193 isDefault = Boolean.valueOf(attribute).booleanValue();
194 } else {
195 isDefault = true;
196 }
197
198
199
200
201
202 NodeList list = config.getElementsByTagName("Federation");
203
204 for (int i = 0; i < list.getLength(); i++ ) {
205
206 attribute = ((Element) list.item(i)).getAttribute("identifier");
207
208 IdPSiteSet siteset = federations.get(attribute);
209
210 if (siteset == null) {
211 LOG.error("Handler " + location + ": could not find metadata for <Federation> with identifier " + attribute + ".");
212 throw new ShibbolethConfigurationException(
213 "Handler " + location + ": could not find metadata for <Federation> identifier " + attribute + ".");
214 }
215
216 siteSets.add(siteset);
217 }
218
219 if (siteSets.size() == 0) {
220
221
222
223 siteSets.addAll(federations.values());
224 }
225
226
227
228
229
230 list = config.getElementsByTagName("PluginInstance");
231
232 for (int i = 0; i < list.getLength(); i++ ) {
233
234 attribute = ((Element) list.item(i)).getAttribute("identifier");
235
236 Plugin plugin = plugins.get(attribute);
237
238 if (plugin == null) {
239 LOG.error("Handler " + location + ": could not find plugin for identifier " + attribute);
240 throw new ShibbolethConfigurationException(
241 "Handler " + location + ": could not find plugin for identifier " + attribute);
242 }
243
244 this.plugins.add(plugin);
245 }
246
247
248
249
250
251
252
253
254 for (IdPSiteSet site: siteSets) {
255 for (Plugin plugin: this.plugins) {
256 site.addPlugin(plugin);
257 }
258 }
259 }
260
261
262
263
264
265
266
267
268
269
270 protected String getLocation() {
271 return location;
272 }
273
274
275
276
277
278 protected boolean isDefault() {
279 return isDefault;
280 }
281
282
283
284
285
286 public void doGet(HttpServletRequest req, HttpServletResponse res) {
287
288 String policy = req.getParameter(POLICY_PARAM_NAME);
289
290 if (null != policy && !KNOWN_POLICY_NAME.equals(policy)) {
291
292
293
294 LOG.error("Unknown policy " + policy);
295 handleError(req, res, "Unknown policy " + policy);
296 return;
297 }
298
299
300
301
302 String requestType = req.getParameter("action");
303
304 if (requestType == null || requestType.equals("")) {
305 requestType = "lookup";
306 }
307
308 try {
309
310 if (requestType.equals("search")) {
311
312 String parameter = req.getParameter("string");
313 if (parameter != null && parameter.equals("")) {
314 parameter = null;
315 }
316 handleLookup(req, res, parameter);
317
318 } else if (requestType.equals("selection")) {
319
320 handleSelection(req, res);
321 } else {
322 handleLookup(req, res, null);
323 }
324 } catch (WayfException we) {
325 LOG.error("Error processing DS request:", we);
326 handleError(req, res, we.getLocalizedMessage());
327 } catch (WayfRequestHandled we) {
328
329
330
331 }
332
333 }
334
335
336
337
338
339
340
341
342
343 private void handleSelection(HttpServletRequest req,
344 HttpServletResponse res) throws WayfRequestHandled, WayfException
345 {
346
347 String idpName = req.getParameter("origin");
348 LOG.debug("Processing handle selection: " + idpName);
349
350 String sPName = getSPId(req);
351
352 if (idpName == null || idpName.equals("")) {
353 handleLookup(req, res, null);
354 return;
355 }
356
357 if (getValue(req, SHIRE_PARAM_NAME) == null) {
358
359
360
361 setupReturnAddress(sPName, req);
362 }
363
364
365
366 IdPSite site = null;
367
368 for (Plugin plugin:plugins) {
369 for (IdPSiteSet idPSiteSet: siteSets) {
370 PluginMetadataParameter param = idPSiteSet.paramFor(plugin);
371 plugin.selected(req, res, param, idpName);
372 if (site == null && idPSiteSet.containsIdP(idpName)) {
373 site = idPSiteSet.getSite(idpName);
374 }
375 }
376 }
377
378 if (site == null) {
379 handleLookup(req, res, null);
380 } else {
381 forwardRequest(req, res, site);
382 }
383 }
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399 private void setupReturnAddress(String spName, HttpServletRequest req) throws WayfException{
400
401 DiscoveryResponseImpl[] discoveryServices;
402 Set<XMLObject> objects = new HashSet<XMLObject>();
403 String defaultName = null;
404
405 for (IdPSiteSet metadataProvider:siteSets) {
406
407
408
409
410
411 if (metadataProvider.containsSP(spName)) {
412
413
414
415
416
417
418 EntityDescriptor entity = metadataProvider.getEntity(spName);
419 List<RoleDescriptor> roles = entity.getRoleDescriptors();
420
421 for (RoleDescriptor role:roles) {
422
423
424
425
426
427 if (role instanceof SPSSODescriptor) {
428
429
430
431
432
433 Extensions exts = role.getExtensions();
434 if (exts != null) {
435 objects.addAll(exts.getOrderedChildren());
436 }
437 }
438 }
439 }
440 }
441
442
443
444
445
446 discoveryServices = new DiscoveryResponseImpl[objects.size()];
447 int dsCount = 0;
448
449 for (XMLObject obj:objects) {
450 if (obj instanceof DiscoveryResponseImpl) {
451 DiscoveryResponseImpl ds = (DiscoveryResponseImpl) obj;
452 discoveryServices[dsCount++] = ds;
453 if (ds.isDefault() || null == defaultName) {
454 defaultName = ds.getLocation();
455 }
456 }
457 }
458
459
460
461
462
463 String returnName = req.getParameter(RETURN_PARAM_NAME);
464
465 if (returnName == null || returnName.length() == 0) {
466 returnName = getValue(req, RETURN_ATTRIBUTE_NAME);
467 }
468
469
470
471
472
473 String returnIndex = req.getParameter(RETURN_INDEX_NAME);
474
475 if (returnName != null && returnName.length() != 0) {
476
477
478
479 String nameNoParam = returnName;
480 URL providedReturnURL;
481 int index = nameNoParam.indexOf('?');
482 boolean found = false;
483
484 if (index >= 0) {
485 nameNoParam = nameNoParam.substring(0,index);
486 }
487
488 try {
489 providedReturnURL = new URL(nameNoParam);
490 } catch (MalformedURLException e) {
491 throw new WayfException("Couldn't parse provided return name " + nameNoParam, e);
492 }
493
494
495 for (DiscoveryResponseImpl disc: discoveryServices) {
496 if (equalsURL(disc, providedReturnURL)) {
497 found = true;
498 break;
499 }
500 }
501 if (!found) {
502 throw new WayfException("Couldn't find endpoint " + nameNoParam + " in metadata");
503 }
504 } else if (returnIndex != null && returnIndex.length() != 0) {
505
506 int index;
507 try {
508 index = Integer.parseInt(returnIndex);
509 } catch (NumberFormatException e) {
510 throw new WayfException("Couldn't convert " + returnIndex + " into an index");
511 }
512
513
514
515
516 boolean found = false;
517
518 for (DiscoveryResponseImpl disc: discoveryServices) {
519 if (index == disc.getIndex()) {
520 found = true;
521 returnName = disc.getLocation();
522 break;
523 }
524 }
525 if (!found) {
526 throw new WayfException("Couldn't not find endpoint " + returnIndex + "in metadata");
527 }
528 } else {
529
530
531
532 returnName = defaultName;
533 }
534
535
536
537
538 req.setAttribute(RETURN_ATTRIBUTE_NAME, returnName);
539 }
540
541
542
543
544
545
546
547
548
549 private static boolean equalsURL(DiscoveryResponseImpl discovery, URL providedName) {
550
551
552
553
554 if (null == discovery) {
555 return false;
556 }
557
558 URL discoveryName;
559 try {
560 discoveryName = new URL(discovery.getLocation());
561 } catch (MalformedURLException e) {
562
563
564
565 LOG.warn("Found invalid discovery end point : " + discovery.getLocation(), e);
566 return false;
567 }
568
569 return providedName.equals(discoveryName);
570
571 }
572
573
574
575
576
577
578
579
580
581
582 private void handleLookup(HttpServletRequest req,
583 HttpServletResponse res,
584 String searchName) throws WayfException, WayfRequestHandled {
585
586 String shire = getValue(req, SHIRE_PARAM_NAME);
587 String providerId = getSPId(req);
588 boolean twoZeroProtocol = (shire == null);
589 boolean isPassive = (twoZeroProtocol &&
590 "true".equalsIgnoreCase(getValue(req, ISPASSIVE_PARAM_NAME)));
591
592 Collection <IdPSiteSetEntry> siteLists = null;
593 Collection<IdPSite> searchResults = null;
594
595 if (config.getProvideListOfLists()) {
596 siteLists = new ArrayList <IdPSiteSetEntry>(siteSets.size());
597 }
598
599 Collection <IdPSite> sites = null;
600 Comparator<IdPSite> comparator = new IdPSite.Compare(req);
601
602 if (config.getProvideList()) {
603 sites = new TreeSet<IdPSite>(comparator);
604 }
605
606 if (searchName != null && !searchName.equals("")) {
607 searchResults = new TreeSet<IdPSite>(comparator);
608 }
609
610 LOG.debug("Processing Idp Lookup for : " + providerId);
611
612
613
614
615
616
617 PluginContext[] ctx = new PluginContext[plugins.size()];
618 List<IdPSite> hintList = new ArrayList<IdPSite>();
619
620 if (twoZeroProtocol) {
621 setupReturnAddress(providerId, req);
622 }
623
624
625
626
627 try {
628 for (IdPSiteSet metadataProvider:siteSets) {
629
630
631
632
633
634 if (metadataProvider.containsSP(providerId) || !config.getLookupSp()) {
635
636 Collection <IdPSite> search = null;
637
638 if (searchResults != null) {
639 search = new TreeSet<IdPSite>(comparator);
640 }
641
642 Map <String, IdPSite> theseSites = metadataProvider.getIdPSites(searchName, config, search);
643
644
645
646
647 for (int i = 0; i < plugins.size(); i++) {
648
649 Plugin plugin = plugins.get(i);
650
651 if (searchResults == null) {
652
653
654
655 ctx[i] = plugin.lookup(req,
656 res,
657 metadataProvider.paramFor(plugin),
658 theseSites,
659 ctx[i],
660 hintList);
661 } else {
662 ctx[i] = plugin.search(req,
663 res,
664 metadataProvider.paramFor(plugin),
665 searchName,
666 theseSites,
667 ctx[i],
668 searchResults,
669 hintList);
670 }
671 }
672
673 if (null == theseSites || theseSites.isEmpty()) {
674 continue;
675 }
676
677
678
679
680
681
682 Collection<IdPSite> values = new TreeSet<IdPSite>(comparator);
683 if (null != theseSites) {
684 values.addAll(theseSites.values());
685 }
686
687 if (siteLists != null) {
688 siteLists.add(new IdPSiteSetEntry(metadataProvider,values));
689 }
690
691 if (sites != null) {
692 sites.addAll(values);
693 }
694
695 if (searchResults != null) {
696 searchResults.addAll(search);
697 }
698 }
699 }
700
701 if (isPassive) {
702
703
704
705 if (0 != hintList.size()) {
706
707
708
709 forwardRequest(req, res, hintList.get(0));
710 } else {
711 forwardRequest(req, res, null);
712 }
713 return;
714 }
715
716
717
718
719
720
721 if (twoZeroProtocol) {
722
723
724
725 String returnString = (String) req.getAttribute(RETURN_ATTRIBUTE_NAME);
726 if (null == returnString || 0 == returnString.length()) {
727 throw new WayfException("Parameter " + RETURN_PARAM_NAME + " not supplied");
728 }
729
730 String returnId = getValue(req, RETURNID_PARAM_NAME);
731 if (null == returnId || 0 == returnId.length()) {
732 returnId = RETURNID_DEFAULT_VALUE;
733 }
734
735
736
737 req.setAttribute(RETURN_ATTRIBUTE_NAME, returnString);
738 req.setAttribute(RETURNID_PARAM_NAME, returnId);
739 req.setAttribute(ENTITYID_PARAM_NAME, providerId);
740
741 } else {
742 String target = getValue(req, TARGET_PARAM_NAME);
743 if (null == target || 0 == target.length()) {
744 throw new WayfException("Could not extract target from provided parameters");
745 }
746 req.setAttribute(SHIRE_PARAM_NAME, shire);
747 req.setAttribute(TARGET_PARAM_NAME, target);
748 req.setAttribute(PROVIDERID_PARAM_NAME, providerId);
749
750
751
752 req.setAttribute("time", new Long(new Date().getTime() / 1000).toString());
753
754 }
755
756
757
758
759
760 setDisplayLanguage(sites, req);
761 req.setAttribute("sites", sites);
762 if (null != siteLists) {
763 for (IdPSiteSetEntry siteSetEntry:siteLists) {
764 setDisplayLanguage(siteSetEntry.getSites(), req);
765 }
766 }
767
768 req.setAttribute("siteLists", siteLists);
769 req.setAttribute("requestURL", req.getRequestURI().toString());
770
771 if (searchResults != null) {
772 if (searchResults.size() != 0) {
773 setDisplayLanguage(searchResults, req);
774 req.setAttribute("searchresults", searchResults);
775 } else {
776 req.setAttribute("searchResultsEmpty", "true");
777 }
778 }
779
780 if (hintList.size() > 0) {
781 setDisplayLanguage(hintList, req);
782 req.setAttribute("cookieList", hintList);
783 }
784
785 LOG.debug("Displaying WAYF selection page.");
786 RequestDispatcher rd = req.getRequestDispatcher(config.getJspFile());
787
788
789
790
791 rd.forward(req, res);
792 } catch (IOException ioe) {
793 LOG.error("Problem displaying WAYF UI.\n" + ioe.getMessage());
794 throw new WayfException("Problem displaying WAYF UI", ioe);
795 } catch (ServletException se) {
796 LOG.error("Problem displaying WAYF UI.\n" + se.getMessage());
797 throw new WayfException("Problem displaying WAYF UI", se);
798 }
799 }
800
801
802
803
804
805
806
807
808 private void setDisplayLanguage(Collection<IdPSite> sites, HttpServletRequest req) {
809
810 if (null == sites) {
811 return;
812 }
813 Locale locale = req.getLocale();
814 if (null == locale) {
815 Locale.getDefault();
816 }
817 String lang = locale.getLanguage();
818
819 for (IdPSite site : sites) {
820 site.setDisplayLanguage(lang);
821 }
822 }
823
824
825
826
827
828
829
830
831
832
833 public static void forwardRequest(HttpServletRequest req, HttpServletResponse res, IdPSite site)
834 throws WayfException {
835
836 String shire = getValue(req, SHIRE_PARAM_NAME);
837 String providerId = getSPId(req);
838 boolean twoZeroProtocol = (shire == null);
839
840 if (!twoZeroProtocol) {
841 String handleService = site.getAddressForWAYF();
842 if (handleService != null ) {
843
844 String target = getValue(req, TARGET_PARAM_NAME);
845 if (null == target || 0 == target.length()) {
846 throw new WayfException("Could not extract target from provided parameters");
847 }
848
849 LOG.info("Redirecting to selected Handle Service: " + handleService);
850 try {
851 StringBuffer buffer = new StringBuffer(handleService +
852 "?" + TARGET_PARAM_NAME + "=");
853 buffer.append(URLEncoder.encode(target, "UTF-8"));
854 buffer.append("&" + SHIRE_PARAM_NAME + "=");
855 buffer.append(URLEncoder.encode(shire, "UTF-8"));
856 buffer.append("&" + PROVIDERID_PARAM_NAME + "=");
857 buffer.append(URLEncoder.encode(providerId, "UTF-8"));
858
859
860
861
862 buffer.append("&" + TIME_PARAM_NAME + "=");
863 buffer.append(new Long(new Date().getTime() / 1000).toString());
864 res.sendRedirect(buffer.toString());
865 } catch (IOException ioe) {
866
867
868
869 throw new WayfException("Error forwarding to IdP: \n" + ioe.getMessage());
870 }
871 } else {
872 String s = "Error finding to IdP: " + site.getDisplayName(req);
873 LOG.error(s);
874 throw new WayfException(s);
875 }
876 } else {
877 String returnUrl = (String) req.getAttribute(RETURN_ATTRIBUTE_NAME);
878
879 if (null == returnUrl || 0 == returnUrl.length()) {
880 throw new WayfException("Could not find return parameter");
881 }
882 try {
883 returnUrl = URLDecoder.decode(returnUrl, "UTF-8");
884 } catch (UnsupportedEncodingException e) {
885 throw new WayfException("Did not understand parameter ", e);
886 }
887 String redirect;
888 if (site != null) {
889 StringBuffer buffer = new StringBuffer(returnUrl);
890
891
892
893 String returnParam = getValue(req, RETURNID_PARAM_NAME);
894 if (null == returnParam || 0 == returnParam.length()) {
895 returnParam = RETURNID_DEFAULT_VALUE;
896 }
897
898
899
900
901 if (returnUrl.indexOf('?') >= 0) {
902
903
904
905 buffer.append("&" + returnParam + "=");
906 } else {
907
908
909
910 buffer.append("?" + returnParam + "=");
911 }
912 buffer.append(site.getName());
913 redirect = buffer.toString();
914 } else {
915
916
917
918 redirect = returnUrl;
919 }
920
921 LOG.debug("Dispatching to " + redirect);
922
923 try {
924 res.sendRedirect(redirect);
925 } catch (IOException ioe) {
926
927
928
929 throw new WayfException("Error forwarding back to Sp: \n" + ioe.getMessage());
930 }
931 }
932 }
933
934
935
936
937
938
939
940
941
942 private void handleError(HttpServletRequest req, HttpServletResponse res, String message) {
943
944 LOG.debug("Displaying WAYF error page.");
945 req.setAttribute("errorText", message);
946 req.setAttribute("requestURL", req.getRequestURI().toString());
947 RequestDispatcher rd = req.getRequestDispatcher(config.getErrorJspFile());
948
949 try {
950 rd.forward(req, res);
951 } catch (IOException ioe) {
952 LOG.error("Problem trying to display WAYF error page: " + ioe.toString());
953 } catch (ServletException se) {
954 LOG.error("Problem trying to display WAYF error page: " + se.toString());
955 }
956 }
957
958
959
960
961
962
963
964 private static String getValue(HttpServletRequest req, String name) {
965
966
967 String value = req.getParameter(name);
968 if (value != null) {
969 return value;
970 }
971 return (String) req.getAttribute(name);
972 }
973
974 private static String getSPId(HttpServletRequest req) throws WayfException {
975
976
977
978
979 String param = req.getParameter(ENTITYID_PARAM_NAME);
980 if (param != null && !(param.length() == 0)) {
981 return param;
982 }
983
984 param = (String) req.getAttribute(ENTITYID_PARAM_NAME);
985 if (param != null && !(param.length() == 0)) {
986 return param;
987 }
988
989
990
991 param = req.getParameter(PROVIDERID_PARAM_NAME);
992 if (param != null && !(param.length() == 0)) {
993 return param;
994 }
995
996 param = (String) req.getAttribute(PROVIDERID_PARAM_NAME);
997 if (param != null && !(param.length() == 0)) {
998 return param;
999 }
1000 throw new WayfException("Could not locate SP identifier in parameters");
1001 }
1002 }